Refactor prompt templates and RAG module
Browse files- notebooks/kaggle_medic_demo.ipynb +6 -54
- src/agents.py +46 -217
- src/config.py +15 -71
- src/db/import_data.py +113 -164
- src/graph.py +36 -205
- src/loader.py +20 -92
- src/prompts.py +7 -43
- src/rag.py +68 -295
- src/state.py +15 -47
- src/tools/rag_tools.py +0 -1
- src/utils.py +67 -216
notebooks/kaggle_medic_demo.ipynb
CHANGED
|
@@ -72,15 +72,7 @@
|
|
| 72 |
"id": "205d4ba2",
|
| 73 |
"metadata": {},
|
| 74 |
"outputs": [],
|
| 75 |
-
"source":
|
| 76 |
-
"%%capture\n",
|
| 77 |
-
"!pip install -q \\\n",
|
| 78 |
-
" \"langgraph>=0.0.15\" \"langchain>=0.3.0\" langchain-text-splitters langchain-community \\\n",
|
| 79 |
-
" \"chromadb>=0.4.0\" sentence-transformers \\\n",
|
| 80 |
-
" \"transformers>=4.50.0\" accelerate bitsandbytes \\\n",
|
| 81 |
-
" streamlit huggingface_hub \\\n",
|
| 82 |
-
" \"pydantic>=2.0\" python-dotenv openpyxl pypdf \"pandas>=2.0\" jq"
|
| 83 |
-
]
|
| 84 |
},
|
| 85 |
{
|
| 86 |
"cell_type": "markdown",
|
|
@@ -267,9 +259,7 @@
|
|
| 267 |
"cell_type": "markdown",
|
| 268 |
"id": "37d17f6b",
|
| 269 |
"metadata": {},
|
| 270 |
-
"source":
|
| 271 |
-
"## 5 · Launch the App"
|
| 272 |
-
]
|
| 273 |
},
|
| 274 |
{
|
| 275 |
"cell_type": "code",
|
|
@@ -277,10 +267,7 @@
|
|
| 277 |
"id": "96ff2d63",
|
| 278 |
"metadata": {},
|
| 279 |
"outputs": [],
|
| 280 |
-
"source": [
|
| 281 |
-
"%%capture\n",
|
| 282 |
-
"!pip install -q localtunnel"
|
| 283 |
-
]
|
| 284 |
},
|
| 285 |
{
|
| 286 |
"cell_type": "code",
|
|
@@ -288,28 +275,7 @@
|
|
| 288 |
"id": "ea6b1788",
|
| 289 |
"metadata": {},
|
| 290 |
"outputs": [],
|
| 291 |
-
"source": [
|
| 292 |
-
"import subprocess, time, requests\n",
|
| 293 |
-
"\n",
|
| 294 |
-
"streamlit_proc = subprocess.Popen(\n",
|
| 295 |
-
" [\"streamlit\", \"run\", \"/kaggle/working/Med-I-C/app.py\",\n",
|
| 296 |
-
" \"--server.port\", \"8501\",\n",
|
| 297 |
-
" \"--server.headless\", \"true\",\n",
|
| 298 |
-
" \"--server.enableCORS\", \"false\"],\n",
|
| 299 |
-
" stdout=subprocess.DEVNULL,\n",
|
| 300 |
-
" stderr=subprocess.DEVNULL,\n",
|
| 301 |
-
")\n",
|
| 302 |
-
"\n",
|
| 303 |
-
"for _ in range(15):\n",
|
| 304 |
-
" try:\n",
|
| 305 |
-
" if requests.get(\"http://localhost:8501\", timeout=2).status_code == 200:\n",
|
| 306 |
-
" print(\"Streamlit running on :8501\")\n",
|
| 307 |
-
" break\n",
|
| 308 |
-
" except Exception:\n",
|
| 309 |
-
" time.sleep(2)\n",
|
| 310 |
-
"else:\n",
|
| 311 |
-
" print(\"Streamlit may still be starting…\")"
|
| 312 |
-
]
|
| 313 |
},
|
| 314 |
{
|
| 315 |
"cell_type": "code",
|
|
@@ -317,21 +283,7 @@
|
|
| 317 |
"id": "00ecfb17",
|
| 318 |
"metadata": {},
|
| 319 |
"outputs": [],
|
| 320 |
-
"source":
|
| 321 |
-
"tunnel_proc = subprocess.Popen(\n",
|
| 322 |
-
" [\"npx\", \"localtunnel\", \"--port\", \"8501\"],\n",
|
| 323 |
-
" stdout=subprocess.PIPE,\n",
|
| 324 |
-
" stderr=subprocess.DEVNULL,\n",
|
| 325 |
-
" text=True,\n",
|
| 326 |
-
")\n",
|
| 327 |
-
"\n",
|
| 328 |
-
"for line in tunnel_proc.stdout:\n",
|
| 329 |
-
" if \"https://\" in line:\n",
|
| 330 |
-
" print(\"\\n\" + \"=\"*50)\n",
|
| 331 |
-
" print(f\" App URL: {line.strip()}\")\n",
|
| 332 |
-
" print(\"=\"*50)\n",
|
| 333 |
-
" break"
|
| 334 |
-
]
|
| 335 |
}
|
| 336 |
],
|
| 337 |
"metadata": {
|
|
@@ -347,4 +299,4 @@
|
|
| 347 |
},
|
| 348 |
"nbformat": 4,
|
| 349 |
"nbformat_minor": 5
|
| 350 |
-
}
|
|
|
|
| 72 |
"id": "205d4ba2",
|
| 73 |
"metadata": {},
|
| 74 |
"outputs": [],
|
| 75 |
+
"source": "%%capture\n!pip install -q \\\n \"langgraph>=0.0.15\" \"langchain>=0.3.0\" langchain-text-splitters langchain-community \\\n \"chromadb>=0.4.0\" sentence-transformers \\\n \"transformers>=4.50.0\" accelerate bitsandbytes \\\n gradio huggingface_hub \\\n \"pydantic>=2.0\" python-dotenv openpyxl pypdf \"pandas>=2.0\" jq"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
},
|
| 77 |
{
|
| 78 |
"cell_type": "markdown",
|
|
|
|
| 259 |
"cell_type": "markdown",
|
| 260 |
"id": "37d17f6b",
|
| 261 |
"metadata": {},
|
| 262 |
+
"source": "## 5 · Launch the Gradio App\n\nTwo tabbed scenarios are exposed in a single Gradio interface:\n\n| Tab | Scenario | Agents active |\n|---|---|---|\n| **Stage 1 — Empirical** | No lab results yet | Agent 1 (MedGemma 4B) → Agent 4 (MedGemma 4B + TxGemma 2B) |\n| **Stage 2 — Targeted** | Culture & sensitivity available | Agent 1 (MedGemma 4B) → Agent 2 (MedGemma 4B) → Agent 3 (MedGemma 4B→27B sub) → Agent 4 (MedGemma 4B + TxGemma 2B) |\n\n`demo.launch(share=True)` prints a public Gradio URL — no extra tunnel needed."
|
|
|
|
|
|
|
| 263 |
},
|
| 264 |
{
|
| 265 |
"cell_type": "code",
|
|
|
|
| 267 |
"id": "96ff2d63",
|
| 268 |
"metadata": {},
|
| 269 |
"outputs": [],
|
| 270 |
+
"source": "import json\nimport sys\n\nsys.path.insert(0, \"/kaggle/working/Med-I-C\")\n\n# ── Demo fallback (used when pipeline errors out before models are warm) ──────\n\ndef _demo_result(patient_data: dict, labs_text) -> dict:\n result = {\n \"stage\": \"targeted\" if labs_text else \"empirical\",\n \"creatinine_clearance_ml_min\": 58.3,\n \"intake_notes\": json.dumps({\n \"patient_summary\": (\n f\"{patient_data.get('age_years')}-year-old {patient_data.get('sex')} · \"\n f\"{patient_data.get('suspected_source', 'infection')}\"\n ),\n \"creatinine_clearance_ml_min\": 58.3,\n \"renal_dose_adjustment_needed\": True,\n \"identified_risk_factors\": patient_data.get(\"comorbidities\", []),\n \"infection_severity\": \"moderate\",\n \"recommended_stage\": \"targeted\" if labs_text else \"empirical\",\n }),\n \"recommendation\": {\n \"primary_antibiotic\": \"Ciprofloxacin\",\n \"dose\": \"500 mg\",\n \"route\": \"Oral\",\n \"frequency\": \"Every 12 hours\",\n \"duration\": \"7 days\",\n \"backup_antibiotic\": \"Nitrofurantoin 100 mg MR BD × 5 days\",\n \"rationale\": (\n \"Community-acquired UTI with moderate renal impairment (CrCl 58 mL/min). \"\n \"Ciprofloxacin provides broad Gram-negative coverage. No dose adjustment \"\n \"required above CrCl 30 mL/min.\"\n ),\n \"references\": [\"IDSA UTI Guidelines 2024\", \"EUCAST Breakpoint Tables v16.0\"],\n },\n \"safety_warnings\": [],\n \"errors\": [],\n }\n if labs_text:\n result[\"vision_notes\"] = json.dumps({\n \"specimen_type\": \"urine\",\n \"identified_organisms\": [{\"organism_name\": \"Escherichia coli\", \"significance\": \"pathogen\"}],\n \"susceptibility_results\": [\n {\"organism\": \"E. coli\", \"antibiotic\": \"Ciprofloxacin\", \"mic_value\": 0.25, \"interpretation\": \"S\"},\n {\"organism\": \"E. coli\", \"antibiotic\": \"Nitrofurantoin\", \"mic_value\": 16, \"interpretation\": \"S\"},\n {\"organism\": \"E. 
coli\", \"antibiotic\": \"Ampicillin\", \"mic_value\": \">32\", \"interpretation\": \"R\"},\n ],\n \"extraction_confidence\": 0.95,\n })\n result[\"trend_notes\"] = json.dumps([{\n \"organism\": \"E. coli\",\n \"antibiotic\": \"Ciprofloxacin\",\n \"risk_level\": \"LOW\",\n \"recommendation\": \"No MIC creep detected — continue current therapy.\",\n }])\n return result\n\n\n# ── Output formatters ─────────────────────────────────────────────────────────\n\ndef _parse_json_field(raw):\n if not raw or raw in (\"No lab data provided\", \"No MIC data available for trend analysis\", \"\"):\n return None\n if isinstance(raw, (dict, list)):\n return raw\n try:\n return json.loads(raw)\n except Exception:\n return None\n\n\ndef format_recommendation(result: dict) -> str:\n lines = [\"## ℞ Recommendation\\n\"]\n rec = result.get(\"recommendation\", {})\n if rec:\n drug = rec.get(\"primary_antibiotic\", \"—\")\n dose = rec.get(\"dose\", \"—\")\n route = rec.get(\"route\", \"—\")\n freq = rec.get(\"frequency\", \"—\")\n dur = rec.get(\"duration\", \"—\")\n lines.append(f\"**Drug:** {drug}\")\n lines.append(\n f\"**Dose:** {dose} · **Route:** {route} \"\n f\" · **Frequency:** {freq} · **Duration:** {dur}\"\n )\n if rec.get(\"backup_antibiotic\"):\n lines.append(f\"**Alternative:** {rec['backup_antibiotic']}\")\n if rec.get(\"rationale\"):\n lines.append(f\"\\n**Clinical rationale:** {rec['rationale']}\")\n if rec.get(\"references\"):\n lines.append(\"\\n**References:**\")\n for ref in rec[\"references\"]:\n lines.append(f\"- {ref}\")\n\n intake = _parse_json_field(result.get(\"intake_notes\", \"\"))\n if isinstance(intake, dict):\n lines.append(\"\\n---\\n## Patient Summary\")\n if intake.get(\"patient_summary\"):\n lines.append(f\"> {intake['patient_summary']}\")\n crcl = result.get(\"creatinine_clearance_ml_min\") or intake.get(\"creatinine_clearance_ml_min\")\n if crcl:\n lines.append(f\"**CrCl:** {float(crcl):.1f} mL/min\")\n if 
intake.get(\"renal_dose_adjustment_needed\"):\n lines.append(\"⚠ **Renal dose adjustment required**\")\n factors = intake.get(\"identified_risk_factors\", [])\n if factors:\n lines.append(f\"**Risk factors:** {', '.join(factors)}\")\n\n warnings = result.get(\"safety_warnings\", [])\n if warnings:\n lines.append(\"\\n---\\n## ⚠ Safety Warnings\")\n for w in warnings:\n lines.append(f\"- {w}\")\n\n errors = result.get(\"errors\", [])\n if errors:\n lines.append(\"\\n---\\n## Errors\")\n for e in errors:\n lines.append(f\"- {e}\")\n\n return \"\\n\".join(lines)\n\n\ndef format_lab_analysis(result: dict) -> str:\n lines = []\n vision = _parse_json_field(result.get(\"vision_notes\", \"\"))\n trend = _parse_json_field(result.get(\"trend_notes\", \"\"))\n\n if vision is None:\n return \"*No lab data processed.*\"\n\n if isinstance(vision, dict):\n lines.append(\"## Lab Extraction\")\n if vision.get(\"specimen_type\"):\n lines.append(f\"**Specimen:** {vision['specimen_type'].capitalize()}\")\n if vision.get(\"extraction_confidence\") is not None:\n conf = float(vision[\"extraction_confidence\"])\n lines.append(f\"**Extraction confidence:** {conf:.0%}\")\n\n orgs = vision.get(\"identified_organisms\", [])\n if orgs:\n lines.append(\"\\n**Identified organisms:**\")\n for o in orgs:\n name = o.get(\"organism_name\", \"Unknown\")\n sig = o.get(\"significance\", \"\")\n lines.append(f\"- **{name}**\" + (f\" — {sig}\" if sig else \"\"))\n\n sus = vision.get(\"susceptibility_results\", [])\n if sus:\n lines.append(\"\\n**Susceptibility results:**\")\n lines.append(\"| Organism | Antibiotic | MIC (mg/L) | Result |\")\n lines.append(\"|---|---|---|---|\")\n for s in sus:\n interp = s.get(\"interpretation\", \"\")\n icon = {\"S\": \"✓ S\", \"R\": \"✗ R\", \"I\": \"~ I\"}.get(interp.upper(), interp)\n lines.append(\n f\"| {s.get('organism','')} | {s.get('antibiotic','')} \"\n f\"| {s.get('mic_value','')} | {icon} |\"\n )\n\n if trend:\n items = trend if isinstance(trend, list) else 
[trend]\n lines.append(\"\\n## MIC Trend Analysis\")\n for item in items:\n if not isinstance(item, dict):\n lines.append(str(item))\n continue\n risk = item.get(\"risk_level\", \"UNKNOWN\").upper()\n icon = {\"HIGH\": \"🚨\", \"MODERATE\": \"⚠\"}.get(risk, \"✓\")\n org = item.get(\"organism\", \"\")\n ab = item.get(\"antibiotic\", \"\")\n label = f\"{org} / {ab} — \" if (org or ab) else \"\"\n lines.append(f\"**{icon} {label}{risk}** — {item.get('recommendation', '')}\")\n\n return \"\\n\".join(lines) if lines else \"*No lab analysis available.*\"\n\n\n# ── Pipeline runner helpers ───────────────────────────────────────────────────\n\ndef _build_patient_data(age, weight, height, sex, creatinine,\n infection_site, suspected_source,\n medications_str, allergies_str, comorbidities_str):\n return {\n \"age_years\": float(age),\n \"weight_kg\": float(weight),\n \"height_cm\": float(height),\n \"sex\": sex,\n \"serum_creatinine_mg_dl\": float(creatinine),\n \"infection_site\": infection_site,\n \"suspected_source\": suspected_source or f\"{infection_site} infection\",\n \"medications\": [m.strip() for m in medications_str.split(\"\\n\") if m.strip()],\n \"allergies\": [a.strip() for a in allergies_str.split(\"\\n\") if a.strip()],\n \"comorbidities\":[c.strip() for c in comorbidities_str.split(\"\\n\") if c.strip()],\n }\n\n\ndef run_empirical_scenario(age, weight, height, sex, creatinine,\n infection_site, suspected_source,\n medications_str, allergies_str, comorbidities_str):\n \"\"\"Stage 1 — Empirical: no lab results.\n Active models: MedGemma 4B (Agent 1) → MedGemma 4B + TxGemma 2B (Agent 4).\n \"\"\"\n patient_data = _build_patient_data(\n age, weight, height, sex, creatinine,\n infection_site, suspected_source,\n medications_str, allergies_str, comorbidities_str,\n )\n try:\n from src.graph import run_pipeline\n result = run_pipeline(patient_data, labs_raw_text=None)\n except Exception as exc:\n result = _demo_result(patient_data, None)\n 
result[\"errors\"].append(f\"[Demo mode — pipeline error: {exc}]\")\n return format_recommendation(result)\n\n\ndef run_targeted_scenario(age, weight, height, sex, creatinine,\n infection_site, suspected_source,\n medications_str, allergies_str, comorbidities_str,\n labs_text):\n \"\"\"Stage 2 — Targeted: lab culture & sensitivity available.\n Active models: MedGemma 4B (Agents 1, 2) → MedGemma 4B→27B sub (Agent 3)\n → MedGemma 4B + TxGemma 2B (Agent 4).\n \"\"\"\n patient_data = _build_patient_data(\n age, weight, height, sex, creatinine,\n infection_site, suspected_source,\n medications_str, allergies_str, comorbidities_str,\n )\n labs = labs_text.strip() if labs_text else None\n try:\n from src.graph import run_pipeline\n result = run_pipeline(patient_data, labs_raw_text=labs)\n except Exception as exc:\n result = _demo_result(patient_data, labs)\n result[\"errors\"].append(f\"[Demo mode — pipeline error: {exc}]\")\n return format_recommendation(result), format_lab_analysis(result)\n\n\nprint(\"Helper functions loaded.\")"
|
|
|
|
|
|
|
|
|
|
| 271 |
},
|
| 272 |
{
|
| 273 |
"cell_type": "code",
|
|
|
|
| 275 |
"id": "ea6b1788",
|
| 276 |
"metadata": {},
|
| 277 |
"outputs": [],
|
| 278 |
+
"source": "import gradio as gr\n\nINFECTION_SITES = [\"urinary\", \"respiratory\", \"bloodstream\", \"skin\", \"intra-abdominal\", \"CNS\", \"other\"]\n\n\ndef _patient_inputs():\n \"\"\"Create patient-demographics input widgets inside the current gr.Blocks context.\"\"\"\n with gr.Row():\n age = gr.Number(label=\"Age (years)\", value=65, minimum=0, maximum=120, precision=0)\n weight = gr.Number(label=\"Weight (kg)\", value=70.0, minimum=1, maximum=300)\n height = gr.Number(label=\"Height (cm)\", value=170.0,minimum=50, maximum=250)\n with gr.Row():\n sex = gr.Dropdown(label=\"Biological sex\", choices=[\"male\", \"female\"], value=\"male\")\n creatinine = gr.Number(label=\"Serum Creatinine (mg/dL)\", value=1.2, minimum=0.1, maximum=20.0)\n infection_site = gr.Dropdown(label=\"Infection site\", choices=INFECTION_SITES, value=\"urinary\")\n suspected_source = gr.Textbox(label=\"Suspected source\",\n placeholder=\"e.g., community-acquired UTI\")\n with gr.Row():\n medications = gr.Textbox(label=\"Current medications (one per line)\",\n placeholder=\"Metformin\\nLisinopril\", lines=3)\n allergies = gr.Textbox(label=\"Drug allergies (one per line)\",\n placeholder=\"Penicillin\\nSulfa\", lines=3)\n comorbidities = gr.Textbox(label=\"Comorbidities / MDR risk factors (one per line)\",\n placeholder=\"Diabetes\\nCKD\\nPrior MRSA\", lines=3)\n return [age, weight, height, sex, creatinine, infection_site,\n suspected_source, medications, allergies, comorbidities]\n\n\nwith gr.Blocks(title=\"AMR-Guard · Med-I-C\", theme=gr.themes.Soft()) as demo:\n\n gr.Markdown(\"\"\"\n# ⚕ AMR-Guard — Infection Lifecycle Orchestrator\n\n**Multi-Agent Clinical Decision Support for Antimicrobial Stewardship**\n\n| Model | Agent(s) | Role |\n|---|---|---|\n| `google/medgemma-4b-it` | 1, 2, 4 | Intake · Lab extraction · Final Rx |\n| `google/medgemma-4b-it` (27B sub on T4) | 3 | MIC trend analysis |\n| `google/txgemma-2b-predict` (9B sub on T4) | 4 (safety) | Drug interaction screening |\n\n> ⚠ 
**Research demo only** — not validated for clinical use. All output must be reviewed by a licensed clinician.\n---\n\"\"\")\n\n with gr.Tabs():\n\n # ──────────────────────────────────────────────────────────────────────\n # TAB 1 — Stage 1: Empirical (no lab results)\n # ──────────────────────────────────────────────────────────────────────\n with gr.Tab(\"Stage 1 — Empirical (no lab results)\"):\n gr.Markdown(\"\"\"\n**Scenario:** Patient presents without culture / sensitivity data.\n\n**Pipeline:** Agent 1 — *Intake Historian* (MedGemma 4B IT) → Agent 4 — *Clinical Pharmacologist* (MedGemma 4B IT + TxGemma 2B)\n\"\"\")\n emp_inputs = _patient_inputs()\n emp_btn = gr.Button(\"Run Empirical Pipeline\", variant=\"primary\")\n emp_output = gr.Markdown(label=\"Recommendation\")\n\n emp_btn.click(\n fn=run_empirical_scenario,\n inputs=emp_inputs,\n outputs=emp_output,\n )\n\n # ──────────────────────────────────────────────────────────────────────\n # TAB 2 — Stage 2: Targeted (culture & sensitivity available)\n # ──────────────────────────────────────────────────────────────────────\n with gr.Tab(\"Stage 2 — Targeted (lab results available)\"):\n gr.Markdown(\"\"\"\n**Scenario:** Culture & sensitivity report (any language) is available.\n\n**Pipeline:** Agent 1 (MedGemma 4B IT) → Agent 2 — *Vision Specialist* (MedGemma 4B IT) → Agent 3 — *Trend Analyst* (MedGemma 27B→4B sub) → Agent 4 (MedGemma 4B IT + TxGemma 2B)\n\"\"\")\n tgt_inputs = _patient_inputs()\n tgt_labs = gr.Textbox(\n label=\"Lab / Culture Report — paste text (any language)\",\n placeholder=(\n \"Organism: Escherichia coli\\n\"\n \"Ciprofloxacin: S MIC 0.25 mg/L\\n\"\n \"Nitrofurantoin: S MIC 16 mg/L\\n\"\n \"Ampicillin: R MIC >32 mg/L\"\n ),\n lines=6,\n )\n tgt_btn = gr.Button(\"Run Targeted Pipeline\", variant=\"primary\")\n\n with gr.Row():\n tgt_rec_output = gr.Markdown(label=\"Recommendation\")\n tgt_lab_output = gr.Markdown(label=\"Lab Analysis & MIC Trend\")\n\n tgt_btn.click(\n 
fn=run_targeted_scenario,\n inputs=tgt_inputs + [tgt_labs],\n outputs=[tgt_rec_output, tgt_lab_output],\n )\n\n gr.Markdown(\"\"\"\n---\n**Knowledge bases:** EUCAST v16.0 · WHO AWaRe 2024 · IDSA AMR Guidance 2024 · ATLAS Surveillance · WHO GLASS · DDInter 2.0 \n**Inference:** HuggingFace Transformers · 4-bit quantization · Kaggle T4 GPU\n\"\"\")\n\nprint(\"Gradio app defined. Run the next cell to launch.\")"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
},
|
| 280 |
{
|
| 281 |
"cell_type": "code",
|
|
|
|
| 283 |
"id": "00ecfb17",
|
| 284 |
"metadata": {},
|
| 285 |
"outputs": [],
|
| 286 |
+
"source": "# share=True creates a public Gradio URL (works out-of-the-box on Kaggle — no localtunnel needed).\n# The URL is printed below and stays live for ~72 hours.\ndemo.launch(share=True, quiet=True)"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 287 |
}
|
| 288 |
],
|
| 289 |
"metadata": {
|
|
|
|
| 299 |
},
|
| 300 |
"nbformat": 4,
|
| 301 |
"nbformat_minor": 5
|
| 302 |
+
}
|
src/agents.py
CHANGED
|
@@ -1,18 +1,15 @@
|
|
| 1 |
"""
|
| 2 |
-
|
| 3 |
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
- Agent 4: Clinical Pharmacologist - Final Rx recommendations + safety checks
|
| 9 |
"""
|
| 10 |
|
| 11 |
-
from __future__ import annotations
|
| 12 |
-
|
| 13 |
import json
|
| 14 |
import logging
|
| 15 |
-
from typing import
|
| 16 |
|
| 17 |
from .config import get_settings
|
| 18 |
from .loader import run_inference, TextModelName
|
|
@@ -40,36 +37,12 @@ from .utils import (
|
|
| 40 |
logger = logging.getLogger(__name__)
|
| 41 |
|
| 42 |
|
| 43 |
-
# =============================================================================
|
| 44 |
-
# AGENT 1: INTAKE HISTORIAN
|
| 45 |
-
# =============================================================================
|
| 46 |
-
|
| 47 |
def run_intake_historian(state: InfectionState) -> InfectionState:
|
| 48 |
-
"""
|
| 49 |
-
Agent 1: Parse patient data, calculate CrCl, identify risk factors.
|
| 50 |
-
|
| 51 |
-
Input state fields used:
|
| 52 |
-
- age_years, weight_kg, height_cm, sex
|
| 53 |
-
- serum_creatinine_mg_dl
|
| 54 |
-
- medications, allergies, comorbidities
|
| 55 |
-
- suspected_source, infection_site
|
| 56 |
-
|
| 57 |
-
Output state fields updated:
|
| 58 |
-
- creatinine_clearance_ml_min
|
| 59 |
-
- intake_notes
|
| 60 |
-
- stage (empirical/targeted)
|
| 61 |
-
- route_to_vision
|
| 62 |
-
"""
|
| 63 |
logger.info("Running Intake Historian agent...")
|
| 64 |
|
| 65 |
-
# Calculate CrCl if we have the required data
|
| 66 |
crcl = None
|
| 67 |
-
if all([
|
| 68 |
-
state.get("age_years"),
|
| 69 |
-
state.get("weight_kg"),
|
| 70 |
-
state.get("serum_creatinine_mg_dl"),
|
| 71 |
-
state.get("sex"),
|
| 72 |
-
]):
|
| 73 |
try:
|
| 74 |
crcl = calculate_crcl(
|
| 75 |
age_years=state["age_years"],
|
|
@@ -85,20 +58,14 @@ def run_intake_historian(state: InfectionState) -> InfectionState:
|
|
| 85 |
logger.warning(f"Could not calculate CrCl: {e}")
|
| 86 |
state.setdefault("errors", []).append(f"CrCl calculation error: {e}")
|
| 87 |
|
| 88 |
-
# Build patient data string for prompt
|
| 89 |
patient_data = _format_patient_data(state)
|
| 90 |
-
|
| 91 |
-
# Get RAG context
|
| 92 |
query = f"treatment {state.get('suspected_source', '')} {state.get('infection_site', '')}"
|
| 93 |
rag_context = get_context_for_agent(
|
| 94 |
agent_name="intake_historian",
|
| 95 |
query=query,
|
| 96 |
-
patient_context={
|
| 97 |
-
"pathogen_type": state.get("suspected_source"),
|
| 98 |
-
},
|
| 99 |
)
|
| 100 |
|
| 101 |
-
# Format the prompt
|
| 102 |
prompt = f"{INTAKE_HISTORIAN_SYSTEM}\n\n{INTAKE_HISTORIAN_PROMPT.format(
|
| 103 |
patient_data=patient_data,
|
| 104 |
medications=', '.join(state.get('medications', [])) or 'None reported',
|
|
@@ -108,34 +75,20 @@ def run_intake_historian(state: InfectionState) -> InfectionState:
|
|
| 108 |
rag_context=rag_context,
|
| 109 |
)}"
|
| 110 |
|
| 111 |
-
# Run inference
|
| 112 |
try:
|
| 113 |
-
response = run_inference(
|
| 114 |
-
prompt=prompt,
|
| 115 |
-
model_name="medgemma_4b",
|
| 116 |
-
max_new_tokens=1024,
|
| 117 |
-
temperature=0.2,
|
| 118 |
-
)
|
| 119 |
-
|
| 120 |
-
# Parse response
|
| 121 |
parsed = safe_json_parse(response)
|
| 122 |
if parsed:
|
| 123 |
state["intake_notes"] = json.dumps(parsed, indent=2)
|
| 124 |
-
|
| 125 |
-
# Update state from parsed response
|
| 126 |
if parsed.get("creatinine_clearance_ml_min") and crcl is None:
|
| 127 |
state["creatinine_clearance_ml_min"] = parsed["creatinine_clearance_ml_min"]
|
| 128 |
-
|
| 129 |
-
# Determine stage
|
| 130 |
-
recommended_stage = parsed.get("recommended_stage", "empirical")
|
| 131 |
-
state["stage"] = recommended_stage
|
| 132 |
-
|
| 133 |
-
# Route to vision if we have lab data to process
|
| 134 |
-
state["route_to_vision"] = bool(state.get("labs_raw_text"))
|
| 135 |
else:
|
| 136 |
state["intake_notes"] = response
|
| 137 |
state["stage"] = "empirical"
|
| 138 |
-
|
|
|
|
|
|
|
| 139 |
|
| 140 |
except Exception as e:
|
| 141 |
logger.error(f"Intake Historian error: {e}")
|
|
@@ -147,23 +100,8 @@ def run_intake_historian(state: InfectionState) -> InfectionState:
|
|
| 147 |
return state
|
| 148 |
|
| 149 |
|
| 150 |
-
# =============================================================================
|
| 151 |
-
# AGENT 2: VISION SPECIALIST
|
| 152 |
-
# =============================================================================
|
| 153 |
-
|
| 154 |
def run_vision_specialist(state: InfectionState) -> InfectionState:
|
| 155 |
-
"""
|
| 156 |
-
Agent 2: Extract structured data from lab reports (text, images, PDFs).
|
| 157 |
-
|
| 158 |
-
Input state fields used:
|
| 159 |
-
- labs_raw_text (extracted text from lab report)
|
| 160 |
-
|
| 161 |
-
Output state fields updated:
|
| 162 |
-
- labs_parsed
|
| 163 |
-
- mic_data
|
| 164 |
-
- vision_notes
|
| 165 |
-
- route_to_trend_analyst
|
| 166 |
-
"""
|
| 167 |
logger.info("Running Vision Specialist agent...")
|
| 168 |
|
| 169 |
labs_raw = state.get("labs_raw_text", "")
|
|
@@ -173,68 +111,54 @@ def run_vision_specialist(state: InfectionState) -> InfectionState:
|
|
| 173 |
state["route_to_trend_analyst"] = False
|
| 174 |
return state
|
| 175 |
|
| 176 |
-
#
|
| 177 |
language = "English (assumed)"
|
| 178 |
-
|
| 179 |
-
# Get RAG context for lab interpretation
|
| 180 |
rag_context = get_context_for_agent(
|
| 181 |
agent_name="vision_specialist",
|
| 182 |
query="culture sensitivity susceptibility interpretation",
|
| 183 |
patient_context={},
|
| 184 |
)
|
| 185 |
|
| 186 |
-
# Format the prompt
|
| 187 |
prompt = f"{VISION_SPECIALIST_SYSTEM}\n\n{VISION_SPECIALIST_PROMPT.format(
|
| 188 |
report_content=labs_raw,
|
| 189 |
source_format='text',
|
| 190 |
language=language,
|
| 191 |
)}"
|
| 192 |
|
| 193 |
-
# Run inference
|
| 194 |
try:
|
| 195 |
-
response = run_inference(
|
| 196 |
-
prompt=prompt,
|
| 197 |
-
model_name="medgemma_4b",
|
| 198 |
-
max_new_tokens=2048,
|
| 199 |
-
temperature=0.1,
|
| 200 |
-
)
|
| 201 |
-
|
| 202 |
-
# Parse response
|
| 203 |
parsed = safe_json_parse(response)
|
| 204 |
if parsed:
|
| 205 |
state["vision_notes"] = json.dumps(parsed, indent=2)
|
| 206 |
|
| 207 |
-
# Extract organisms and susceptibility data
|
| 208 |
organisms = parsed.get("identified_organisms", [])
|
| 209 |
susceptibility = parsed.get("susceptibility_results", [])
|
| 210 |
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
"
|
| 216 |
-
"
|
| 217 |
-
"
|
| 218 |
-
"mic_unit": result.get("mic_unit", "mg/L"),
|
| 219 |
-
"interpretation": result.get("interpretation"),
|
| 220 |
}
|
| 221 |
-
|
|
|
|
| 222 |
|
| 223 |
state["mic_data"] = mic_data
|
| 224 |
-
state["labs_parsed"] = [
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
|
|
|
| 231 |
state["route_to_trend_analyst"] = len(mic_data) > 0
|
| 232 |
|
| 233 |
-
# Check for critical findings
|
| 234 |
critical = parsed.get("critical_findings", [])
|
| 235 |
if critical:
|
| 236 |
state.setdefault("safety_warnings", []).extend(critical)
|
| 237 |
-
|
| 238 |
else:
|
| 239 |
state["vision_notes"] = response
|
| 240 |
state["route_to_trend_analyst"] = False
|
|
@@ -249,23 +173,8 @@ def run_vision_specialist(state: InfectionState) -> InfectionState:
|
|
| 249 |
return state
|
| 250 |
|
| 251 |
|
| 252 |
-
# =============================================================================
|
| 253 |
-
# AGENT 3: TREND ANALYST
|
| 254 |
-
# =============================================================================
|
| 255 |
-
|
| 256 |
def run_trend_analyst(state: InfectionState) -> InfectionState:
|
| 257 |
-
"""
|
| 258 |
-
Agent 3: Analyze MIC trends and detect resistance velocity.
|
| 259 |
-
|
| 260 |
-
Input state fields used:
|
| 261 |
-
- mic_data (current MIC readings)
|
| 262 |
-
- Historical MIC data (if available)
|
| 263 |
-
|
| 264 |
-
Output state fields updated:
|
| 265 |
-
- mic_trend_summary
|
| 266 |
-
- trend_notes
|
| 267 |
-
- safety_warnings (if high risk detected)
|
| 268 |
-
"""
|
| 269 |
logger.info("Running Trend Analyst agent...")
|
| 270 |
|
| 271 |
mic_data = state.get("mic_data", [])
|
|
@@ -274,14 +183,12 @@ def run_trend_analyst(state: InfectionState) -> InfectionState:
|
|
| 274 |
state["trend_notes"] = "No MIC data available for trend analysis"
|
| 275 |
return state
|
| 276 |
|
| 277 |
-
# For each organism-antibiotic pair, analyze trends
|
| 278 |
trend_results = []
|
| 279 |
|
| 280 |
for mic in mic_data:
|
| 281 |
organism = mic.get("organism", "Unknown")
|
| 282 |
antibiotic = mic.get("antibiotic", "Unknown")
|
| 283 |
|
| 284 |
-
# Get RAG context for breakpoints
|
| 285 |
rag_context = get_context_for_agent(
|
| 286 |
agent_name="trend_analyst",
|
| 287 |
query=f"breakpoint {organism} {antibiotic}",
|
|
@@ -292,10 +199,9 @@ def run_trend_analyst(state: InfectionState) -> InfectionState:
|
|
| 292 |
},
|
| 293 |
)
|
| 294 |
|
| 295 |
-
#
|
| 296 |
mic_history = [{"date": "current", "mic_value": mic.get("mic_value", "0")}]
|
| 297 |
|
| 298 |
-
# Format prompt
|
| 299 |
prompt = f"{TREND_ANALYST_SYSTEM}\n\n{TREND_ANALYST_PROMPT.format(
|
| 300 |
organism=organism,
|
| 301 |
antibiotic=antibiotic,
|
|
@@ -305,18 +211,16 @@ def run_trend_analyst(state: InfectionState) -> InfectionState:
|
|
| 305 |
)}"
|
| 306 |
|
| 307 |
try:
|
|
|
|
| 308 |
response = run_inference(
|
| 309 |
prompt=prompt,
|
| 310 |
-
model_name="medgemma_27b",
|
| 311 |
max_new_tokens=1024,
|
| 312 |
temperature=0.2,
|
| 313 |
)
|
| 314 |
-
|
| 315 |
parsed = safe_json_parse(response)
|
| 316 |
if parsed:
|
| 317 |
trend_results.append(parsed)
|
| 318 |
-
|
| 319 |
-
# Add safety warning if high/critical risk
|
| 320 |
risk_level = parsed.get("risk_level", "LOW")
|
| 321 |
if risk_level in ["HIGH", "CRITICAL"]:
|
| 322 |
warning = f"MIC trend alert for {organism}/{antibiotic}: {parsed.get('recommendation', 'Review needed')}"
|
|
@@ -328,10 +232,8 @@ def run_trend_analyst(state: InfectionState) -> InfectionState:
|
|
| 328 |
logger.error(f"Trend analysis error for {organism}/{antibiotic}: {e}")
|
| 329 |
trend_results.append({"error": str(e)})
|
| 330 |
|
| 331 |
-
# Summarize trends
|
| 332 |
state["trend_notes"] = json.dumps(trend_results, indent=2)
|
| 333 |
|
| 334 |
-
# Create summary
|
| 335 |
high_risk_count = sum(1 for t in trend_results if t.get("risk_level") in ["HIGH", "CRITICAL"])
|
| 336 |
state["mic_trend_summary"] = f"Analyzed {len(trend_results)} organism-antibiotic pairs. High-risk findings: {high_risk_count}"
|
| 337 |
|
|
@@ -339,43 +241,21 @@ def run_trend_analyst(state: InfectionState) -> InfectionState:
|
|
| 339 |
return state
|
| 340 |
|
| 341 |
|
| 342 |
-
# =============================================================================
|
| 343 |
-
# AGENT 4: CLINICAL PHARMACOLOGIST
|
| 344 |
-
# =============================================================================
|
| 345 |
-
|
| 346 |
def run_clinical_pharmacologist(state: InfectionState) -> InfectionState:
|
| 347 |
-
"""
|
| 348 |
-
Agent 4: Generate final antibiotic recommendation with safety checks.
|
| 349 |
-
|
| 350 |
-
Input state fields used:
|
| 351 |
-
- intake_notes, vision_notes, trend_notes
|
| 352 |
-
- age_years, weight_kg, creatinine_clearance_ml_min
|
| 353 |
-
- allergies, medications
|
| 354 |
-
- infection_site, suspected_source
|
| 355 |
-
|
| 356 |
-
Output state fields updated:
|
| 357 |
-
- recommendation
|
| 358 |
-
- pharmacology_notes
|
| 359 |
-
- safety_warnings (additional alerts)
|
| 360 |
-
"""
|
| 361 |
logger.info("Running Clinical Pharmacologist agent...")
|
| 362 |
|
| 363 |
-
# Gather all previous agent outputs
|
| 364 |
intake_summary = state.get("intake_notes", "No intake data")
|
| 365 |
lab_results = state.get("vision_notes", "No lab data")
|
| 366 |
trend_analysis = state.get("trend_notes", "No trend data")
|
| 367 |
|
| 368 |
-
# Get RAG context
|
| 369 |
query = f"treatment {state.get('suspected_source', '')} antibiotic recommendation"
|
| 370 |
rag_context = get_context_for_agent(
|
| 371 |
agent_name="clinical_pharmacologist",
|
| 372 |
query=query,
|
| 373 |
-
patient_context={
|
| 374 |
-
"proposed_antibiotic": None, # Will be determined by agent
|
| 375 |
-
},
|
| 376 |
)
|
| 377 |
|
| 378 |
-
# Format prompt
|
| 379 |
prompt = f"{CLINICAL_PHARMACOLOGIST_SYSTEM}\n\n{CLINICAL_PHARMACOLOGIST_PROMPT.format(
|
| 380 |
intake_summary=intake_summary,
|
| 381 |
lab_results=lab_results,
|
|
@@ -392,18 +272,11 @@ def run_clinical_pharmacologist(state: InfectionState) -> InfectionState:
|
|
| 392 |
)}"
|
| 393 |
|
| 394 |
try:
|
| 395 |
-
response = run_inference(
|
| 396 |
-
prompt=prompt,
|
| 397 |
-
model_name="medgemma_4b",
|
| 398 |
-
max_new_tokens=2048,
|
| 399 |
-
temperature=0.2,
|
| 400 |
-
)
|
| 401 |
-
|
| 402 |
parsed = safe_json_parse(response)
|
| 403 |
if parsed:
|
| 404 |
state["pharmacology_notes"] = json.dumps(parsed, indent=2)
|
| 405 |
|
| 406 |
-
# Build recommendation
|
| 407 |
primary = parsed.get("primary_recommendation", {})
|
| 408 |
recommendation = {
|
| 409 |
"primary_antibiotic": primary.get("antibiotic"),
|
|
@@ -416,19 +289,16 @@ def run_clinical_pharmacologist(state: InfectionState) -> InfectionState:
|
|
| 416 |
"safety_alerts": [a.get("message") for a in parsed.get("safety_alerts", [])],
|
| 417 |
}
|
| 418 |
|
| 419 |
-
# Add alternative if provided
|
| 420 |
alt = parsed.get("alternative_recommendation", {})
|
| 421 |
if alt.get("antibiotic"):
|
| 422 |
recommendation["backup_antibiotic"] = alt.get("antibiotic")
|
| 423 |
|
| 424 |
state["recommendation"] = recommendation
|
| 425 |
|
| 426 |
-
# Add safety alerts to state
|
| 427 |
for alert in parsed.get("safety_alerts", []):
|
| 428 |
if alert.get("level") in ["WARNING", "CRITICAL"]:
|
| 429 |
state.setdefault("safety_warnings", []).append(alert.get("message"))
|
| 430 |
|
| 431 |
-
# Run TxGemma safety check (optional)
|
| 432 |
if primary.get("antibiotic"):
|
| 433 |
safety_result = _run_txgemma_safety_check(
|
| 434 |
antibiotic=primary.get("antibiotic"),
|
|
@@ -441,7 +311,6 @@ def run_clinical_pharmacologist(state: InfectionState) -> InfectionState:
|
|
| 441 |
)
|
| 442 |
if safety_result:
|
| 443 |
state.setdefault("debug_log", []).append(f"TxGemma safety: {safety_result}")
|
| 444 |
-
|
| 445 |
else:
|
| 446 |
state["pharmacology_notes"] = response
|
| 447 |
state["recommendation"] = {"rationale": response}
|
|
@@ -455,12 +324,8 @@ def run_clinical_pharmacologist(state: InfectionState) -> InfectionState:
|
|
| 455 |
return state
|
| 456 |
|
| 457 |
|
| 458 |
-
# =============================================================================
|
| 459 |
-
# HELPER FUNCTIONS
|
| 460 |
-
# =============================================================================
|
| 461 |
-
|
| 462 |
def _format_patient_data(state: InfectionState) -> str:
|
| 463 |
-
"""Format patient
|
| 464 |
lines = []
|
| 465 |
|
| 466 |
if state.get("patient_id"):
|
|
@@ -505,11 +370,7 @@ def _run_txgemma_safety_check(
|
|
| 505 |
crcl: Optional[float],
|
| 506 |
medications: list,
|
| 507 |
) -> Optional[str]:
|
| 508 |
-
"""
|
| 509 |
-
Run TxGemma safety check (supplementary).
|
| 510 |
-
|
| 511 |
-
TxGemma is used only for safety validation, not primary recommendations.
|
| 512 |
-
"""
|
| 513 |
try:
|
| 514 |
prompt = TXGEMMA_SAFETY_PROMPT.format(
|
| 515 |
antibiotic=antibiotic,
|
|
@@ -520,25 +381,13 @@ def _run_txgemma_safety_check(
|
|
| 520 |
crcl=crcl or "Unknown",
|
| 521 |
medications=", ".join(medications) if medications else "None",
|
| 522 |
)
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
prompt=prompt,
|
| 526 |
-
model_name="txgemma_9b", # Agent 4 safety: TxGemma 9B per PLAN.md (env maps to 2B on limited GPU)
|
| 527 |
-
max_new_tokens=256,
|
| 528 |
-
temperature=0.1,
|
| 529 |
-
)
|
| 530 |
-
|
| 531 |
-
return response
|
| 532 |
-
|
| 533 |
except Exception as e:
|
| 534 |
logger.warning(f"TxGemma safety check failed: {e}")
|
| 535 |
return None
|
| 536 |
|
| 537 |
|
| 538 |
-
# =============================================================================
|
| 539 |
-
# AGENT REGISTRY
|
| 540 |
-
# =============================================================================
|
| 541 |
-
|
| 542 |
AGENTS = {
|
| 543 |
"intake_historian": run_intake_historian,
|
| 544 |
"vision_specialist": run_vision_specialist,
|
|
@@ -548,27 +397,7 @@ AGENTS = {
|
|
| 548 |
|
| 549 |
|
| 550 |
def run_agent(agent_name: str, state: InfectionState) -> InfectionState:
|
| 551 |
-
"""
|
| 552 |
-
Run a specific agent by name.
|
| 553 |
-
|
| 554 |
-
Args:
|
| 555 |
-
agent_name: Name of the agent to run
|
| 556 |
-
state: Current infection state
|
| 557 |
-
|
| 558 |
-
Returns:
|
| 559 |
-
Updated infection state
|
| 560 |
-
"""
|
| 561 |
if agent_name not in AGENTS:
|
| 562 |
raise ValueError(f"Unknown agent: {agent_name}")
|
| 563 |
-
|
| 564 |
return AGENTS[agent_name](state)
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
__all__ = [
|
| 568 |
-
"run_intake_historian",
|
| 569 |
-
"run_vision_specialist",
|
| 570 |
-
"run_trend_analyst",
|
| 571 |
-
"run_clinical_pharmacologist",
|
| 572 |
-
"run_agent",
|
| 573 |
-
"AGENTS",
|
| 574 |
-
]
|
|
|
|
| 1 |
"""
|
| 2 |
+
Four-agent pipeline for the infection lifecycle workflow.
|
| 3 |
|
| 4 |
+
Agent 1 - Intake Historian: parse patient data, calculate CrCl, identify AMR risk factors
|
| 5 |
+
Agent 2 - Vision Specialist: extract organisms and MIC values from lab reports
|
| 6 |
+
Agent 3 - Trend Analyst: detect MIC creep and resistance velocity
|
| 7 |
+
Agent 4 - Clinical Pharmacologist: generate final antibiotic recommendation with safety checks
|
|
|
|
| 8 |
"""
|
| 9 |
|
|
|
|
|
|
|
| 10 |
import json
|
| 11 |
import logging
|
| 12 |
+
from typing import Optional
|
| 13 |
|
| 14 |
from .config import get_settings
|
| 15 |
from .loader import run_inference, TextModelName
|
|
|
|
| 37 |
logger = logging.getLogger(__name__)
|
| 38 |
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
def run_intake_historian(state: InfectionState) -> InfectionState:
|
| 41 |
+
"""Parse patient data, calculate CrCl, identify MDR risk factors, and set the treatment stage."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
logger.info("Running Intake Historian agent...")
|
| 43 |
|
|
|
|
| 44 |
crcl = None
|
| 45 |
+
if all([state.get("age_years"), state.get("weight_kg"), state.get("serum_creatinine_mg_dl"), state.get("sex")]):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
try:
|
| 47 |
crcl = calculate_crcl(
|
| 48 |
age_years=state["age_years"],
|
|
|
|
| 58 |
logger.warning(f"Could not calculate CrCl: {e}")
|
| 59 |
state.setdefault("errors", []).append(f"CrCl calculation error: {e}")
|
| 60 |
|
|
|
|
| 61 |
patient_data = _format_patient_data(state)
|
|
|
|
|
|
|
| 62 |
query = f"treatment {state.get('suspected_source', '')} {state.get('infection_site', '')}"
|
| 63 |
rag_context = get_context_for_agent(
|
| 64 |
agent_name="intake_historian",
|
| 65 |
query=query,
|
| 66 |
+
patient_context={"pathogen_type": state.get("suspected_source")},
|
|
|
|
|
|
|
| 67 |
)
|
| 68 |
|
|
|
|
| 69 |
prompt = f"{INTAKE_HISTORIAN_SYSTEM}\n\n{INTAKE_HISTORIAN_PROMPT.format(
|
| 70 |
patient_data=patient_data,
|
| 71 |
medications=', '.join(state.get('medications', [])) or 'None reported',
|
|
|
|
| 75 |
rag_context=rag_context,
|
| 76 |
)}"
|
| 77 |
|
|
|
|
| 78 |
try:
|
| 79 |
+
response = run_inference(prompt=prompt, model_name="medgemma_4b", max_new_tokens=1024, temperature=0.2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
parsed = safe_json_parse(response)
|
| 81 |
if parsed:
|
| 82 |
state["intake_notes"] = json.dumps(parsed, indent=2)
|
|
|
|
|
|
|
| 83 |
if parsed.get("creatinine_clearance_ml_min") and crcl is None:
|
| 84 |
state["creatinine_clearance_ml_min"] = parsed["creatinine_clearance_ml_min"]
|
| 85 |
+
state["stage"] = parsed.get("recommended_stage", "empirical")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
else:
|
| 87 |
state["intake_notes"] = response
|
| 88 |
state["stage"] = "empirical"
|
| 89 |
+
|
| 90 |
+
# Route to vision only if lab text was provided
|
| 91 |
+
state["route_to_vision"] = bool(state.get("labs_raw_text"))
|
| 92 |
|
| 93 |
except Exception as e:
|
| 94 |
logger.error(f"Intake Historian error: {e}")
|
|
|
|
| 100 |
return state
|
| 101 |
|
| 102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
def run_vision_specialist(state: InfectionState) -> InfectionState:
|
| 104 |
+
"""Extract pathogen names, MIC values, and S/I/R interpretations from lab report text."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
logger.info("Running Vision Specialist agent...")
|
| 106 |
|
| 107 |
labs_raw = state.get("labs_raw_text", "")
|
|
|
|
| 111 |
state["route_to_trend_analyst"] = False
|
| 112 |
return state
|
| 113 |
|
| 114 |
+
# Language detection is not implemented; we assume English or instruct the model to translate
|
| 115 |
language = "English (assumed)"
|
|
|
|
|
|
|
| 116 |
rag_context = get_context_for_agent(
|
| 117 |
agent_name="vision_specialist",
|
| 118 |
query="culture sensitivity susceptibility interpretation",
|
| 119 |
patient_context={},
|
| 120 |
)
|
| 121 |
|
|
|
|
| 122 |
prompt = f"{VISION_SPECIALIST_SYSTEM}\n\n{VISION_SPECIALIST_PROMPT.format(
|
| 123 |
report_content=labs_raw,
|
| 124 |
source_format='text',
|
| 125 |
language=language,
|
| 126 |
)}"
|
| 127 |
|
|
|
|
| 128 |
try:
|
| 129 |
+
response = run_inference(prompt=prompt, model_name="medgemma_4b", max_new_tokens=2048, temperature=0.1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
parsed = safe_json_parse(response)
|
| 131 |
if parsed:
|
| 132 |
state["vision_notes"] = json.dumps(parsed, indent=2)
|
| 133 |
|
|
|
|
| 134 |
organisms = parsed.get("identified_organisms", [])
|
| 135 |
susceptibility = parsed.get("susceptibility_results", [])
|
| 136 |
|
| 137 |
+
mic_data = [
|
| 138 |
+
{
|
| 139 |
+
"organism": normalize_organism_name(r.get("organism", "")),
|
| 140 |
+
"antibiotic": normalize_antibiotic_name(r.get("antibiotic", "")),
|
| 141 |
+
"mic_value": str(r.get("mic_value", "")),
|
| 142 |
+
"mic_unit": r.get("mic_unit", "mg/L"),
|
| 143 |
+
"interpretation": r.get("interpretation"),
|
|
|
|
|
|
|
| 144 |
}
|
| 145 |
+
for r in susceptibility
|
| 146 |
+
]
|
| 147 |
|
| 148 |
state["mic_data"] = mic_data
|
| 149 |
+
state["labs_parsed"] = [
|
| 150 |
+
{
|
| 151 |
+
"name": org.get("organism_name", "Unknown"),
|
| 152 |
+
"value": org.get("colony_count", ""),
|
| 153 |
+
"flag": "pathogen" if org.get("significance") == "pathogen" else None,
|
| 154 |
+
}
|
| 155 |
+
for org in organisms
|
| 156 |
+
]
|
| 157 |
state["route_to_trend_analyst"] = len(mic_data) > 0
|
| 158 |
|
|
|
|
| 159 |
critical = parsed.get("critical_findings", [])
|
| 160 |
if critical:
|
| 161 |
state.setdefault("safety_warnings", []).extend(critical)
|
|
|
|
| 162 |
else:
|
| 163 |
state["vision_notes"] = response
|
| 164 |
state["route_to_trend_analyst"] = False
|
|
|
|
| 173 |
return state
|
| 174 |
|
| 175 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
def run_trend_analyst(state: InfectionState) -> InfectionState:
|
| 177 |
+
"""Analyze MIC trends per organism-antibiotic pair and flag high-risk creep."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
logger.info("Running Trend Analyst agent...")
|
| 179 |
|
| 180 |
mic_data = state.get("mic_data", [])
|
|
|
|
| 183 |
state["trend_notes"] = "No MIC data available for trend analysis"
|
| 184 |
return state
|
| 185 |
|
|
|
|
| 186 |
trend_results = []
|
| 187 |
|
| 188 |
for mic in mic_data:
|
| 189 |
organism = mic.get("organism", "Unknown")
|
| 190 |
antibiotic = mic.get("antibiotic", "Unknown")
|
| 191 |
|
|
|
|
| 192 |
rag_context = get_context_for_agent(
|
| 193 |
agent_name="trend_analyst",
|
| 194 |
query=f"breakpoint {organism} {antibiotic}",
|
|
|
|
| 199 |
},
|
| 200 |
)
|
| 201 |
|
| 202 |
+
# Single time-point history — trend analysis requires historical data in production
|
| 203 |
mic_history = [{"date": "current", "mic_value": mic.get("mic_value", "0")}]
|
| 204 |
|
|
|
|
| 205 |
prompt = f"{TREND_ANALYST_SYSTEM}\n\n{TREND_ANALYST_PROMPT.format(
|
| 206 |
organism=organism,
|
| 207 |
antibiotic=antibiotic,
|
|
|
|
| 211 |
)}"
|
| 212 |
|
| 213 |
try:
|
| 214 |
+
# Agent 3 is designed for MedGemma 27B; on limited GPU the env var maps this to 4B
|
| 215 |
response = run_inference(
|
| 216 |
prompt=prompt,
|
| 217 |
+
model_name="medgemma_27b",
|
| 218 |
max_new_tokens=1024,
|
| 219 |
temperature=0.2,
|
| 220 |
)
|
|
|
|
| 221 |
parsed = safe_json_parse(response)
|
| 222 |
if parsed:
|
| 223 |
trend_results.append(parsed)
|
|
|
|
|
|
|
| 224 |
risk_level = parsed.get("risk_level", "LOW")
|
| 225 |
if risk_level in ["HIGH", "CRITICAL"]:
|
| 226 |
warning = f"MIC trend alert for {organism}/{antibiotic}: {parsed.get('recommendation', 'Review needed')}"
|
|
|
|
| 232 |
logger.error(f"Trend analysis error for {organism}/{antibiotic}: {e}")
|
| 233 |
trend_results.append({"error": str(e)})
|
| 234 |
|
|
|
|
| 235 |
state["trend_notes"] = json.dumps(trend_results, indent=2)
|
| 236 |
|
|
|
|
| 237 |
high_risk_count = sum(1 for t in trend_results if t.get("risk_level") in ["HIGH", "CRITICAL"])
|
| 238 |
state["mic_trend_summary"] = f"Analyzed {len(trend_results)} organism-antibiotic pairs. High-risk findings: {high_risk_count}"
|
| 239 |
|
|
|
|
| 241 |
return state
|
| 242 |
|
| 243 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
def run_clinical_pharmacologist(state: InfectionState) -> InfectionState:
|
| 245 |
+
"""Synthesize all agent outputs into a final antibiotic recommendation with safety checks."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
logger.info("Running Clinical Pharmacologist agent...")
|
| 247 |
|
|
|
|
| 248 |
intake_summary = state.get("intake_notes", "No intake data")
|
| 249 |
lab_results = state.get("vision_notes", "No lab data")
|
| 250 |
trend_analysis = state.get("trend_notes", "No trend data")
|
| 251 |
|
|
|
|
| 252 |
query = f"treatment {state.get('suspected_source', '')} antibiotic recommendation"
|
| 253 |
rag_context = get_context_for_agent(
|
| 254 |
agent_name="clinical_pharmacologist",
|
| 255 |
query=query,
|
| 256 |
+
patient_context={"proposed_antibiotic": None},
|
|
|
|
|
|
|
| 257 |
)
|
| 258 |
|
|
|
|
| 259 |
prompt = f"{CLINICAL_PHARMACOLOGIST_SYSTEM}\n\n{CLINICAL_PHARMACOLOGIST_PROMPT.format(
|
| 260 |
intake_summary=intake_summary,
|
| 261 |
lab_results=lab_results,
|
|
|
|
| 272 |
)}"
|
| 273 |
|
| 274 |
try:
|
| 275 |
+
response = run_inference(prompt=prompt, model_name="medgemma_4b", max_new_tokens=2048, temperature=0.2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
parsed = safe_json_parse(response)
|
| 277 |
if parsed:
|
| 278 |
state["pharmacology_notes"] = json.dumps(parsed, indent=2)
|
| 279 |
|
|
|
|
| 280 |
primary = parsed.get("primary_recommendation", {})
|
| 281 |
recommendation = {
|
| 282 |
"primary_antibiotic": primary.get("antibiotic"),
|
|
|
|
| 289 |
"safety_alerts": [a.get("message") for a in parsed.get("safety_alerts", [])],
|
| 290 |
}
|
| 291 |
|
|
|
|
| 292 |
alt = parsed.get("alternative_recommendation", {})
|
| 293 |
if alt.get("antibiotic"):
|
| 294 |
recommendation["backup_antibiotic"] = alt.get("antibiotic")
|
| 295 |
|
| 296 |
state["recommendation"] = recommendation
|
| 297 |
|
|
|
|
| 298 |
for alert in parsed.get("safety_alerts", []):
|
| 299 |
if alert.get("level") in ["WARNING", "CRITICAL"]:
|
| 300 |
state.setdefault("safety_warnings", []).append(alert.get("message"))
|
| 301 |
|
|
|
|
| 302 |
if primary.get("antibiotic"):
|
| 303 |
safety_result = _run_txgemma_safety_check(
|
| 304 |
antibiotic=primary.get("antibiotic"),
|
|
|
|
| 311 |
)
|
| 312 |
if safety_result:
|
| 313 |
state.setdefault("debug_log", []).append(f"TxGemma safety: {safety_result}")
|
|
|
|
| 314 |
else:
|
| 315 |
state["pharmacology_notes"] = response
|
| 316 |
state["recommendation"] = {"rationale": response}
|
|
|
|
| 324 |
return state
|
| 325 |
|
| 326 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 327 |
def _format_patient_data(state: InfectionState) -> str:
|
| 328 |
+
"""Format patient fields from state into a readable string for prompt injection."""
|
| 329 |
lines = []
|
| 330 |
|
| 331 |
if state.get("patient_id"):
|
|
|
|
| 370 |
crcl: Optional[float],
|
| 371 |
medications: list,
|
| 372 |
) -> Optional[str]:
|
| 373 |
+
"""Run a supplementary TxGemma toxicology check on the proposed prescription."""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 374 |
try:
|
| 375 |
prompt = TXGEMMA_SAFETY_PROMPT.format(
|
| 376 |
antibiotic=antibiotic,
|
|
|
|
| 381 |
crcl=crcl or "Unknown",
|
| 382 |
medications=", ".join(medications) if medications else "None",
|
| 383 |
)
|
| 384 |
+
# Agent 4 safety check uses TxGemma 9B; on limited GPU the env var maps this to 2B
|
| 385 |
+
return run_inference(prompt=prompt, model_name="txgemma_9b", max_new_tokens=256, temperature=0.1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 386 |
except Exception as e:
|
| 387 |
logger.warning(f"TxGemma safety check failed: {e}")
|
| 388 |
return None
|
| 389 |
|
| 390 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 391 |
AGENTS = {
|
| 392 |
"intake_historian": run_intake_historian,
|
| 393 |
"vision_specialist": run_vision_specialist,
|
|
|
|
| 397 |
|
| 398 |
|
| 399 |
def run_agent(agent_name: str, state: InfectionState) -> InfectionState:
|
| 400 |
+
"""Dispatch to a named agent."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 401 |
if agent_name not in AGENTS:
|
| 402 |
raise ValueError(f"Unknown agent: {agent_name}")
|
|
|
|
| 403 |
return AGENTS[agent_name](state)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/config.py
CHANGED
|
@@ -1,6 +1,4 @@
|
|
| 1 |
|
| 2 |
-
from __future__ import annotations
|
| 3 |
-
|
| 4 |
import os
|
| 5 |
from functools import lru_cache
|
| 6 |
from pathlib import Path
|
|
@@ -9,104 +7,63 @@ from typing import Literal, Optional
|
|
| 9 |
from dotenv import load_dotenv
|
| 10 |
from pydantic import BaseModel, Field
|
| 11 |
|
| 12 |
-
|
| 13 |
-
# Load variables from a local .env if present (handy for local dev)
|
| 14 |
load_dotenv()
|
| 15 |
|
| 16 |
|
| 17 |
class Settings(BaseModel):
|
| 18 |
"""
|
| 19 |
-
|
| 20 |
|
| 21 |
-
|
| 22 |
-
|
| 23 |
"""
|
| 24 |
|
| 25 |
-
# ------------------------------------------------------------------
|
| 26 |
-
# General environment
|
| 27 |
-
# ------------------------------------------------------------------
|
| 28 |
environment: Literal["local", "kaggle", "production"] = Field(
|
| 29 |
default_factory=lambda: os.getenv("MEDIC_ENV", "local")
|
| 30 |
)
|
| 31 |
-
|
| 32 |
project_root: Path = Field(
|
| 33 |
default_factory=lambda: Path(__file__).resolve().parents[1]
|
| 34 |
)
|
| 35 |
-
|
| 36 |
data_dir: Path = Field(
|
| 37 |
-
default_factory=lambda: Path(
|
| 38 |
-
os.getenv("MEDIC_DATA_DIR", "data")
|
| 39 |
-
)
|
| 40 |
)
|
| 41 |
-
|
| 42 |
chroma_db_dir: Path = Field(
|
| 43 |
-
default_factory=lambda: Path(
|
| 44 |
-
os.getenv("MEDIC_CHROMA_DB_DIR", "data/chroma_db")
|
| 45 |
-
)
|
| 46 |
)
|
| 47 |
|
| 48 |
-
# ------------------------------------------------------------------
|
| 49 |
-
# Model + deployment preferences
|
| 50 |
-
# ------------------------------------------------------------------
|
| 51 |
default_backend: Literal["vertex", "local"] = Field(
|
| 52 |
default_factory=lambda: os.getenv("MEDIC_DEFAULT_BACKEND", "vertex") # type: ignore[arg-type]
|
| 53 |
)
|
| 54 |
-
|
| 55 |
-
# Quantization mode for local models
|
| 56 |
quantization: Literal["none", "4bit"] = Field(
|
| 57 |
default_factory=lambda: os.getenv("MEDIC_QUANTIZATION", "4bit") # type: ignore[arg-type]
|
| 58 |
)
|
| 59 |
-
|
| 60 |
-
# Embedding model used for ChromaDB / RAG
|
| 61 |
embedding_model_name: str = Field(
|
| 62 |
-
default_factory=lambda: os.getenv(
|
| 63 |
-
"MEDIC_EMBEDDING_MODEL",
|
| 64 |
-
"sentence-transformers/all-MiniLM-L6-v2",
|
| 65 |
-
)
|
| 66 |
)
|
| 67 |
|
| 68 |
-
#
|
| 69 |
-
# Vertex AI configuration (MedGemma / TxGemma hosted on Vertex)
|
| 70 |
-
# ------------------------------------------------------------------
|
| 71 |
use_vertex: bool = Field(
|
| 72 |
-
default_factory=lambda: os.getenv("MEDIC_USE_VERTEX", "true").lower()
|
| 73 |
-
in {"1", "true", "yes"}
|
| 74 |
)
|
| 75 |
-
|
| 76 |
vertex_project_id: Optional[str] = Field(
|
| 77 |
default_factory=lambda: os.getenv("MEDIC_VERTEX_PROJECT_ID")
|
| 78 |
)
|
| 79 |
vertex_location: str = Field(
|
| 80 |
default_factory=lambda: os.getenv("MEDIC_VERTEX_LOCATION", "us-central1")
|
| 81 |
)
|
| 82 |
-
|
| 83 |
-
# Model IDs as expected by Vertex / langchain-google-vertexai
|
| 84 |
vertex_medgemma_4b_model: str = Field(
|
| 85 |
-
default_factory=lambda: os.getenv(
|
| 86 |
-
"MEDIC_VERTEX_MEDGEMMA_4B_MODEL",
|
| 87 |
-
"med-gemma-4b-it",
|
| 88 |
-
)
|
| 89 |
)
|
| 90 |
vertex_medgemma_27b_model: str = Field(
|
| 91 |
-
default_factory=lambda: os.getenv(
|
| 92 |
-
"MEDIC_VERTEX_MEDGEMMA_27B_MODEL",
|
| 93 |
-
"med-gemma-27b-text-it",
|
| 94 |
-
)
|
| 95 |
)
|
| 96 |
vertex_txgemma_9b_model: str = Field(
|
| 97 |
-
default_factory=lambda: os.getenv(
|
| 98 |
-
"MEDIC_VERTEX_TXGEMMA_9B_MODEL",
|
| 99 |
-
"tx-gemma-9b",
|
| 100 |
-
)
|
| 101 |
)
|
| 102 |
vertex_txgemma_2b_model: str = Field(
|
| 103 |
-
default_factory=lambda: os.getenv(
|
| 104 |
-
"MEDIC_VERTEX_TXGEMMA_2B_MODEL",
|
| 105 |
-
"tx-gemma-2b",
|
| 106 |
-
)
|
| 107 |
)
|
| 108 |
-
|
| 109 |
-
# Standard GOOGLE_APPLICATION_CREDENTIALS path, if needed
|
| 110 |
google_application_credentials: Optional[Path] = Field(
|
| 111 |
default_factory=lambda: (
|
| 112 |
Path(os.environ["GOOGLE_APPLICATION_CREDENTIALS"])
|
|
@@ -115,9 +72,7 @@ class Settings(BaseModel):
|
|
| 115 |
)
|
| 116 |
)
|
| 117 |
|
| 118 |
-
#
|
| 119 |
-
# Local model paths (for offline / Kaggle GPU usage)
|
| 120 |
-
# ------------------------------------------------------------------
|
| 121 |
local_medgemma_4b_model: Optional[str] = Field(
|
| 122 |
default_factory=lambda: os.getenv("MEDIC_LOCAL_MEDGEMMA_4B_MODEL")
|
| 123 |
)
|
|
@@ -134,17 +89,6 @@ class Settings(BaseModel):
|
|
| 134 |
|
| 135 |
@lru_cache(maxsize=1)
|
| 136 |
def get_settings() -> Settings:
|
| 137 |
-
"""
|
| 138 |
-
Return a cached Settings instance.
|
| 139 |
-
|
| 140 |
-
Use this helper everywhere instead of instantiating Settings directly:
|
| 141 |
-
|
| 142 |
-
from src.config import get_settings
|
| 143 |
-
settings = get_settings()
|
| 144 |
-
"""
|
| 145 |
-
|
| 146 |
return Settings()
|
| 147 |
|
| 148 |
-
|
| 149 |
-
__all__ = ["Settings", "get_settings"]
|
| 150 |
-
|
|
|
|
| 1 |
|
|
|
|
|
|
|
| 2 |
import os
|
| 3 |
from functools import lru_cache
|
| 4 |
from pathlib import Path
|
|
|
|
| 7 |
from dotenv import load_dotenv
|
| 8 |
from pydantic import BaseModel, Field
|
| 9 |
|
|
|
|
|
|
|
| 10 |
load_dotenv()
|
| 11 |
|
| 12 |
|
| 13 |
class Settings(BaseModel):
|
| 14 |
"""
|
| 15 |
+
All configuration for Med-I-C, read from environment variables.
|
| 16 |
|
| 17 |
+
Supports three deployment targets via MEDIC_ENV: local, kaggle, production.
|
| 18 |
+
Backend selection (vertex or local) is controlled by MEDIC_DEFAULT_BACKEND.
|
| 19 |
"""
|
| 20 |
|
|
|
|
|
|
|
|
|
|
| 21 |
environment: Literal["local", "kaggle", "production"] = Field(
|
| 22 |
default_factory=lambda: os.getenv("MEDIC_ENV", "local")
|
| 23 |
)
|
|
|
|
| 24 |
project_root: Path = Field(
|
| 25 |
default_factory=lambda: Path(__file__).resolve().parents[1]
|
| 26 |
)
|
|
|
|
| 27 |
data_dir: Path = Field(
|
| 28 |
+
default_factory=lambda: Path(os.getenv("MEDIC_DATA_DIR", "data"))
|
|
|
|
|
|
|
| 29 |
)
|
|
|
|
| 30 |
chroma_db_dir: Path = Field(
|
| 31 |
+
default_factory=lambda: Path(os.getenv("MEDIC_CHROMA_DB_DIR", "data/chroma_db"))
|
|
|
|
|
|
|
| 32 |
)
|
| 33 |
|
|
|
|
|
|
|
|
|
|
| 34 |
default_backend: Literal["vertex", "local"] = Field(
|
| 35 |
default_factory=lambda: os.getenv("MEDIC_DEFAULT_BACKEND", "vertex") # type: ignore[arg-type]
|
| 36 |
)
|
| 37 |
+
# 4-bit quantization via bitsandbytes (local backend only)
|
|
|
|
| 38 |
quantization: Literal["none", "4bit"] = Field(
|
| 39 |
default_factory=lambda: os.getenv("MEDIC_QUANTIZATION", "4bit") # type: ignore[arg-type]
|
| 40 |
)
|
|
|
|
|
|
|
| 41 |
embedding_model_name: str = Field(
|
| 42 |
+
default_factory=lambda: os.getenv("MEDIC_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
|
|
|
|
|
|
|
|
|
|
| 43 |
)
|
| 44 |
|
| 45 |
+
# Vertex AI settings
|
|
|
|
|
|
|
| 46 |
use_vertex: bool = Field(
|
| 47 |
+
default_factory=lambda: os.getenv("MEDIC_USE_VERTEX", "true").lower() in {"1", "true", "yes"}
|
|
|
|
| 48 |
)
|
|
|
|
| 49 |
vertex_project_id: Optional[str] = Field(
|
| 50 |
default_factory=lambda: os.getenv("MEDIC_VERTEX_PROJECT_ID")
|
| 51 |
)
|
| 52 |
vertex_location: str = Field(
|
| 53 |
default_factory=lambda: os.getenv("MEDIC_VERTEX_LOCATION", "us-central1")
|
| 54 |
)
|
|
|
|
|
|
|
| 55 |
vertex_medgemma_4b_model: str = Field(
|
| 56 |
+
default_factory=lambda: os.getenv("MEDIC_VERTEX_MEDGEMMA_4B_MODEL", "med-gemma-4b-it")
|
|
|
|
|
|
|
|
|
|
| 57 |
)
|
| 58 |
vertex_medgemma_27b_model: str = Field(
|
| 59 |
+
default_factory=lambda: os.getenv("MEDIC_VERTEX_MEDGEMMA_27B_MODEL", "med-gemma-27b-text-it")
|
|
|
|
|
|
|
|
|
|
| 60 |
)
|
| 61 |
vertex_txgemma_9b_model: str = Field(
|
| 62 |
+
default_factory=lambda: os.getenv("MEDIC_VERTEX_TXGEMMA_9B_MODEL", "tx-gemma-9b")
|
|
|
|
|
|
|
|
|
|
| 63 |
)
|
| 64 |
vertex_txgemma_2b_model: str = Field(
|
| 65 |
+
default_factory=lambda: os.getenv("MEDIC_VERTEX_TXGEMMA_2B_MODEL", "tx-gemma-2b")
|
|
|
|
|
|
|
|
|
|
| 66 |
)
|
|
|
|
|
|
|
| 67 |
google_application_credentials: Optional[Path] = Field(
|
| 68 |
default_factory=lambda: (
|
| 69 |
Path(os.environ["GOOGLE_APPLICATION_CREDENTIALS"])
|
|
|
|
| 72 |
)
|
| 73 |
)
|
| 74 |
|
| 75 |
+
# Local HuggingFace model paths (used when MEDIC_DEFAULT_BACKEND=local)
|
|
|
|
|
|
|
| 76 |
local_medgemma_4b_model: Optional[str] = Field(
|
| 77 |
default_factory=lambda: os.getenv("MEDIC_LOCAL_MEDGEMMA_4B_MODEL")
|
| 78 |
)
|
|
|
|
| 89 |
|
| 90 |
@lru_cache(maxsize=1)
|
| 91 |
def get_settings() -> Settings:
|
| 92 |
+
"""Return the cached Settings singleton. Import this instead of instantiating Settings directly."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
return Settings()
|
| 94 |
|
|
|
|
|
|
|
|
|
src/db/import_data.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
"""Data import scripts for Med-I-C structured documents."""
|
| 2 |
|
| 3 |
import pandas as pd
|
| 4 |
-
import re
|
| 5 |
from pathlib import Path
|
| 6 |
from .database import (
|
| 7 |
get_connection, init_database, execute_many,
|
|
@@ -10,7 +9,7 @@ from .database import (
|
|
| 10 |
|
| 11 |
|
| 12 |
def safe_float(value):
|
| 13 |
-
"""
|
| 14 |
if pd.isna(value):
|
| 15 |
return None
|
| 16 |
try:
|
|
@@ -20,7 +19,7 @@ def safe_float(value):
|
|
| 20 |
|
| 21 |
|
| 22 |
def safe_int(value):
|
| 23 |
-
"""
|
| 24 |
if pd.isna(value):
|
| 25 |
return None
|
| 26 |
try:
|
|
@@ -29,41 +28,46 @@ def safe_int(value):
|
|
| 29 |
return None
|
| 30 |
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
def classify_severity(description: str) -> str:
|
| 33 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
if not description:
|
| 35 |
return "unknown"
|
| 36 |
|
| 37 |
desc_lower = description.lower()
|
| 38 |
|
| 39 |
-
# Major severity indicators
|
| 40 |
major_keywords = [
|
| 41 |
"cardiotoxic", "nephrotoxic", "hepatotoxic", "neurotoxic",
|
| 42 |
"fatal", "death", "severe", "contraindicated", "arrhythmia",
|
| 43 |
"qt prolongation", "seizure", "bleeding", "hemorrhage",
|
| 44 |
-
"serotonin syndrome", "neuroleptic malignant"
|
| 45 |
]
|
| 46 |
-
|
| 47 |
-
# Moderate severity indicators
|
| 48 |
moderate_keywords = [
|
| 49 |
"increase", "decrease", "reduce", "enhance", "inhibit",
|
| 50 |
"metabolism", "concentration", "absorption", "excretion",
|
| 51 |
-
"therapeutic effect", "adverse effect", "toxicity"
|
| 52 |
]
|
| 53 |
|
| 54 |
-
for
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
for keyword in moderate_keywords:
|
| 59 |
-
if keyword in desc_lower:
|
| 60 |
-
return "moderate"
|
| 61 |
-
|
| 62 |
return "minor"
|
| 63 |
|
| 64 |
|
| 65 |
def import_eml_antibiotics() -> int:
|
| 66 |
-
"""Import WHO EML antibiotic classification data."""
|
| 67 |
print("Importing EML antibiotic data...")
|
| 68 |
|
| 69 |
eml_files = {
|
|
@@ -79,29 +83,21 @@ def import_eml_antibiotics() -> int:
|
|
| 79 |
continue
|
| 80 |
|
| 81 |
try:
|
| 82 |
-
# Use openpyxl directly with read_only=True for faster loading
|
| 83 |
import openpyxl
|
| 84 |
wb = openpyxl.load_workbook(filepath, read_only=True)
|
| 85 |
ws = wb.active
|
| 86 |
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
|
| 92 |
-
|
| 93 |
-
for row_idx, row in enumerate(ws.iter_rows(min_row=2, values_only=True), start=2):
|
| 94 |
row_dict = dict(zip(headers, row))
|
| 95 |
-
|
| 96 |
medicine = str(row_dict.get('medicine_name', row_dict.get('medicine', '')))
|
| 97 |
-
if not medicine or medicine
|
| 98 |
continue
|
| 99 |
|
| 100 |
-
def safe_str(val):
|
| 101 |
-
if val is None or pd.isna(val):
|
| 102 |
-
return ''
|
| 103 |
-
return str(val)
|
| 104 |
-
|
| 105 |
records.append((
|
| 106 |
medicine,
|
| 107 |
category,
|
|
@@ -114,20 +110,20 @@ def import_eml_antibiotics() -> int:
|
|
| 114 |
))
|
| 115 |
|
| 116 |
wb.close()
|
| 117 |
-
print(f" Loaded {
|
| 118 |
|
| 119 |
except Exception as e:
|
| 120 |
print(f" Warning: Error reading {filepath}: {e}")
|
| 121 |
continue
|
| 122 |
|
| 123 |
if records:
|
| 124 |
-
|
| 125 |
-
INSERT INTO eml_antibiotics
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
print(f" Imported {len(records)} EML antibiotic records total")
|
| 132 |
|
| 133 |
return len(records)
|
|
@@ -143,95 +139,78 @@ def import_atlas_susceptibility() -> int:
|
|
| 143 |
print(f" Warning: {filepath} not found, skipping...")
|
| 144 |
return 0
|
| 145 |
|
| 146 |
-
# Read the raw data to find the header row and extract region
|
| 147 |
df_raw = pd.read_excel(filepath, sheet_name="Percent", header=None)
|
| 148 |
|
| 149 |
-
#
|
| 150 |
region = "Unknown"
|
| 151 |
-
for
|
| 152 |
cell = str(row.iloc[0]) if pd.notna(row.iloc[0]) else ""
|
| 153 |
if "from" in cell.lower():
|
| 154 |
-
# Extract country from "Percentage Susceptibility from Argentina"
|
| 155 |
parts = cell.split("from")
|
| 156 |
if len(parts) > 1:
|
| 157 |
region = parts[1].strip()
|
| 158 |
break
|
| 159 |
|
| 160 |
-
#
|
| 161 |
-
header_row = 4
|
| 162 |
for idx, row in df_raw.head(10).iterrows():
|
| 163 |
if any('Antibacterial' in str(v) for v in row.values if pd.notna(v)):
|
| 164 |
header_row = idx
|
| 165 |
break
|
| 166 |
|
| 167 |
-
# Read with proper header
|
| 168 |
df = pd.read_excel(filepath, sheet_name="Percent", header=header_row)
|
| 169 |
-
|
| 170 |
-
# Standardize column names
|
| 171 |
df.columns = [str(col).strip().lower().replace(' ', '_').replace('.', '') for col in df.columns]
|
| 172 |
|
| 173 |
records = []
|
| 174 |
for _, row in df.iterrows():
|
| 175 |
antibiotic = str(row.get('antibacterial', ''))
|
| 176 |
-
|
| 177 |
-
# Skip empty or non-antibiotic rows
|
| 178 |
if not antibiotic or antibiotic == 'nan' or 'omitted' in antibiotic.lower():
|
| 179 |
continue
|
| 180 |
if 'in vitro' in antibiotic.lower() or 'table cells' in antibiotic.lower():
|
| 181 |
continue
|
| 182 |
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
pct_s = row.get('susc', row.get('susceptible', None))
|
| 186 |
-
pct_i = row.get('int', row.get('intermediate', None))
|
| 187 |
-
pct_r = row.get('res', row.get('resistant', None))
|
| 188 |
-
|
| 189 |
-
# Use safe conversion functions
|
| 190 |
-
n_int = safe_int(n_value)
|
| 191 |
-
s_float = safe_float(pct_s)
|
| 192 |
|
| 193 |
if n_int is not None and s_float is not None:
|
| 194 |
records.append((
|
| 195 |
-
"General",
|
| 196 |
-
"",
|
| 197 |
antibiotic,
|
| 198 |
s_float,
|
| 199 |
-
safe_float(
|
| 200 |
-
safe_float(
|
| 201 |
n_int,
|
| 202 |
-
2024,
|
| 203 |
region,
|
| 204 |
-
"ATLAS"
|
| 205 |
))
|
| 206 |
|
| 207 |
if records:
|
| 208 |
-
|
| 209 |
-
INSERT INTO atlas_susceptibility
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
print(f" Imported {len(records)} ATLAS susceptibility records from {region}")
|
| 217 |
|
| 218 |
return len(records)
|
| 219 |
|
| 220 |
|
| 221 |
def import_mic_breakpoints() -> int:
|
| 222 |
-
"""Import EUCAST MIC breakpoint tables."""
|
| 223 |
print("Importing MIC breakpoint data...")
|
| 224 |
|
| 225 |
filepath = DOCS_DIR / "mic_breakpoints" / "v_16.0__BreakpointTables.xlsx"
|
| 226 |
-
|
| 227 |
if not filepath.exists():
|
| 228 |
print(f" Warning: {filepath} not found, skipping...")
|
| 229 |
return 0
|
| 230 |
|
| 231 |
-
# Get all sheet names
|
| 232 |
xl = pd.ExcelFile(filepath)
|
| 233 |
-
|
| 234 |
-
# Skip non-pathogen sheets
|
| 235 |
skip_sheets = {'Content', 'Changes', 'Notes', 'Guidance', 'Dosages',
|
| 236 |
'Technical uncertainty', 'PK PD breakpoints', 'PK PD cutoffs'}
|
| 237 |
|
|
@@ -239,58 +218,48 @@ def import_mic_breakpoints() -> int:
|
|
| 239 |
for sheet_name in xl.sheet_names:
|
| 240 |
if sheet_name in skip_sheets:
|
| 241 |
continue
|
| 242 |
-
|
| 243 |
try:
|
| 244 |
df = pd.read_excel(filepath, sheet_name=sheet_name, header=None)
|
| 245 |
-
|
| 246 |
-
# Try to find antibiotic data - look for rows with MIC values
|
| 247 |
-
pathogen_group = sheet_name
|
| 248 |
-
|
| 249 |
-
# Simple heuristic: look for rows that might contain antibiotic names and MIC values
|
| 250 |
-
for idx, row in df.iterrows():
|
| 251 |
row_values = [str(v).strip() for v in row.values if pd.notna(v)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
|
| 253 |
-
#
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
records.append((
|
| 272 |
-
pathogen_group,
|
| 273 |
-
potential_antibiotic,
|
| 274 |
-
None, # route
|
| 275 |
-
mic_values[0] if len(mic_values) > 0 else None, # S breakpoint
|
| 276 |
-
mic_values[1] if len(mic_values) > 1 else None, # R breakpoint
|
| 277 |
-
None, # disk S
|
| 278 |
-
None, # disk R
|
| 279 |
-
None, # notes
|
| 280 |
-
"16.0"
|
| 281 |
-
))
|
| 282 |
except Exception as e:
|
| 283 |
print(f" Warning: Could not parse sheet '{sheet_name}': {e}")
|
| 284 |
continue
|
| 285 |
|
| 286 |
if records:
|
| 287 |
-
|
| 288 |
-
INSERT INTO mic_breakpoints
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
print(f" Imported {len(records)} MIC breakpoint records")
|
| 295 |
|
| 296 |
return len(records)
|
|
@@ -303,36 +272,32 @@ INTERACTIONS_CSV = DOCS_DIR / "drug_safety" / "db_drug_interactions.csv"
|
|
| 303 |
|
| 304 |
def _resolve_interactions_csv() -> Path | None:
|
| 305 |
"""
|
| 306 |
-
|
| 307 |
|
| 308 |
-
|
| 309 |
-
1. docs/drug_safety/db_drug_interactions.csv
|
| 310 |
-
2. /kaggle/input/drug-drug-interactions/
|
| 311 |
-
3. Kaggle API download
|
| 312 |
"""
|
| 313 |
-
# 1. Already present
|
| 314 |
if INTERACTIONS_CSV.exists():
|
| 315 |
return INTERACTIONS_CSV
|
| 316 |
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
|
| 322 |
-
# 3. Download via Kaggle API
|
| 323 |
print(f" CSV not found — downloading from Kaggle dataset '{KAGGLE_DATASET}' ...")
|
| 324 |
try:
|
| 325 |
-
import kaggle # noqa: F401
|
|
|
|
| 326 |
dest = INTERACTIONS_CSV.parent
|
| 327 |
dest.mkdir(parents=True, exist_ok=True)
|
| 328 |
-
import subprocess
|
| 329 |
result = subprocess.run(
|
| 330 |
-
["kaggle", "datasets", "download", "-d", KAGGLE_DATASET,
|
| 331 |
-
"--unzip", "-p", str(dest)],
|
| 332 |
capture_output=True, text=True,
|
| 333 |
)
|
| 334 |
if result.returncode == 0:
|
| 335 |
-
# Find the downloaded CSV
|
| 336 |
for f in dest.glob("*.csv"):
|
| 337 |
print(f" Downloaded: {f.name}")
|
| 338 |
return f
|
|
@@ -347,23 +312,18 @@ def _resolve_interactions_csv() -> Path | None:
|
|
| 347 |
|
| 348 |
|
| 349 |
def import_drug_interactions(limit: int = None) -> int:
|
| 350 |
-
"""Import drug-drug
|
| 351 |
print("Importing drug interactions data...")
|
| 352 |
|
| 353 |
filepath = _resolve_interactions_csv()
|
| 354 |
-
|
| 355 |
if filepath is None:
|
| 356 |
print(" Skipping drug interactions — CSV unavailable.")
|
| 357 |
print(f" To fix: attach the Kaggle dataset '{KAGGLE_DATASET}' to your notebook,")
|
| 358 |
print(" or set up ~/.kaggle/kaggle.json for API access.")
|
| 359 |
return 0
|
| 360 |
|
| 361 |
-
# Read CSV in chunks due to large size
|
| 362 |
-
chunk_size = 10000
|
| 363 |
total_records = 0
|
| 364 |
-
|
| 365 |
-
for chunk in pd.read_csv(filepath, chunksize=chunk_size):
|
| 366 |
-
# Standardize column names
|
| 367 |
chunk.columns = [col.strip().lower().replace(' ', '_') for col in chunk.columns]
|
| 368 |
|
| 369 |
records = []
|
|
@@ -372,19 +332,14 @@ def import_drug_interactions(limit: int = None) -> int:
|
|
| 372 |
drug_2 = str(row.get('drug_2', row.get('drug2', row.iloc[1] if len(row) > 1 else '')))
|
| 373 |
description = str(row.get('interaction_description', row.get('description',
|
| 374 |
row.get('interaction', row.iloc[2] if len(row) > 2 else ''))))
|
| 375 |
-
|
| 376 |
-
severity = classify_severity(description)
|
| 377 |
-
|
| 378 |
if drug_1 and drug_2:
|
| 379 |
-
records.append((drug_1, drug_2, description,
|
| 380 |
|
| 381 |
if records:
|
| 382 |
-
|
| 383 |
-
INSERT INTO drug_interactions
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
"""
|
| 387 |
-
execute_many(query, records)
|
| 388 |
total_records += len(records)
|
| 389 |
|
| 390 |
if limit and total_records >= limit:
|
|
@@ -395,24 +350,19 @@ def import_drug_interactions(limit: int = None) -> int:
|
|
| 395 |
|
| 396 |
|
| 397 |
def import_all_data(interactions_limit: int = None) -> dict:
|
| 398 |
-
"""
|
| 399 |
print(f"\n{'='*50}")
|
| 400 |
print("Med-I-C Data Import")
|
| 401 |
print(f"{'='*50}\n")
|
| 402 |
|
| 403 |
-
# Initialize database
|
| 404 |
init_database()
|
| 405 |
|
| 406 |
-
# Clear existing data
|
| 407 |
with get_connection() as conn:
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
conn.execute("DELETE FROM mic_breakpoints")
|
| 411 |
-
conn.execute("DELETE FROM drug_interactions")
|
| 412 |
conn.commit()
|
| 413 |
print("Cleared existing data\n")
|
| 414 |
|
| 415 |
-
# Import all data
|
| 416 |
results = {
|
| 417 |
"eml_antibiotics": import_eml_antibiotics(),
|
| 418 |
"atlas_susceptibility": import_atlas_susceptibility(),
|
|
@@ -430,5 +380,4 @@ def import_all_data(interactions_limit: int = None) -> dict:
|
|
| 430 |
|
| 431 |
|
| 432 |
if __name__ == "__main__":
|
| 433 |
-
# Import with a limit on interactions for faster demo
|
| 434 |
import_all_data(interactions_limit=50000)
|
|
|
|
| 1 |
"""Data import scripts for Med-I-C structured documents."""
|
| 2 |
|
| 3 |
import pandas as pd
|
|
|
|
| 4 |
from pathlib import Path
|
| 5 |
from .database import (
|
| 6 |
get_connection, init_database, execute_many,
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
def safe_float(value):
|
| 12 |
+
"""Convert value to float; return None if the value is NaN or non-numeric."""
|
| 13 |
if pd.isna(value):
|
| 14 |
return None
|
| 15 |
try:
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
def safe_int(value):
|
| 22 |
+
"""Convert value to int via float; return None if the value is NaN or non-numeric."""
|
| 23 |
if pd.isna(value):
|
| 24 |
return None
|
| 25 |
try:
|
|
|
|
| 28 |
return None
|
| 29 |
|
| 30 |
|
| 31 |
+
def safe_str(value) -> str:
|
| 32 |
+
"""Convert value to string; return empty string for None or NaN."""
|
| 33 |
+
if value is None or pd.isna(value):
|
| 34 |
+
return ''
|
| 35 |
+
return str(value)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
def classify_severity(description: str) -> str:
|
| 39 |
+
"""
|
| 40 |
+
Classify drug interaction severity from the interaction description text.
|
| 41 |
+
|
| 42 |
+
Returns 'major', 'moderate', or 'minor' based on keyword presence.
|
| 43 |
+
Major keywords take precedence over moderate.
|
| 44 |
+
"""
|
| 45 |
if not description:
|
| 46 |
return "unknown"
|
| 47 |
|
| 48 |
desc_lower = description.lower()
|
| 49 |
|
|
|
|
| 50 |
major_keywords = [
|
| 51 |
"cardiotoxic", "nephrotoxic", "hepatotoxic", "neurotoxic",
|
| 52 |
"fatal", "death", "severe", "contraindicated", "arrhythmia",
|
| 53 |
"qt prolongation", "seizure", "bleeding", "hemorrhage",
|
| 54 |
+
"serotonin syndrome", "neuroleptic malignant",
|
| 55 |
]
|
|
|
|
|
|
|
| 56 |
moderate_keywords = [
|
| 57 |
"increase", "decrease", "reduce", "enhance", "inhibit",
|
| 58 |
"metabolism", "concentration", "absorption", "excretion",
|
| 59 |
+
"therapeutic effect", "adverse effect", "toxicity",
|
| 60 |
]
|
| 61 |
|
| 62 |
+
if any(kw in desc_lower for kw in major_keywords):
|
| 63 |
+
return "major"
|
| 64 |
+
if any(kw in desc_lower for kw in moderate_keywords):
|
| 65 |
+
return "moderate"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
return "minor"
|
| 67 |
|
| 68 |
|
| 69 |
def import_eml_antibiotics() -> int:
|
| 70 |
+
"""Import WHO EML antibiotic classification data from the three AWaRe Excel files."""
|
| 71 |
print("Importing EML antibiotic data...")
|
| 72 |
|
| 73 |
eml_files = {
|
|
|
|
| 83 |
continue
|
| 84 |
|
| 85 |
try:
|
|
|
|
| 86 |
import openpyxl
|
| 87 |
wb = openpyxl.load_workbook(filepath, read_only=True)
|
| 88 |
ws = wb.active
|
| 89 |
|
| 90 |
+
headers = [
|
| 91 |
+
str(cell.value).strip().lower().replace(' ', '_') if cell.value else f'col_{i}'
|
| 92 |
+
for i, cell in enumerate(ws[1])
|
| 93 |
+
]
|
| 94 |
|
| 95 |
+
for row in ws.iter_rows(min_row=2, values_only=True):
|
|
|
|
| 96 |
row_dict = dict(zip(headers, row))
|
|
|
|
| 97 |
medicine = str(row_dict.get('medicine_name', row_dict.get('medicine', '')))
|
| 98 |
+
if not medicine or medicine in ('None', 'nan'):
|
| 99 |
continue
|
| 100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
records.append((
|
| 102 |
medicine,
|
| 103 |
category,
|
|
|
|
| 110 |
))
|
| 111 |
|
| 112 |
wb.close()
|
| 113 |
+
print(f" Loaded {sum(1 for r in records if r[1] == category)} from {category}")
|
| 114 |
|
| 115 |
except Exception as e:
|
| 116 |
print(f" Warning: Error reading {filepath}: {e}")
|
| 117 |
continue
|
| 118 |
|
| 119 |
if records:
|
| 120 |
+
execute_many(
|
| 121 |
+
"""INSERT INTO eml_antibiotics
|
| 122 |
+
(medicine_name, who_category, eml_section, formulations,
|
| 123 |
+
indication, atc_codes, combined_with, status)
|
| 124 |
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
|
| 125 |
+
records,
|
| 126 |
+
)
|
| 127 |
print(f" Imported {len(records)} EML antibiotic records total")
|
| 128 |
|
| 129 |
return len(records)
|
|
|
|
| 139 |
print(f" Warning: {filepath} not found, skipping...")
|
| 140 |
return 0
|
| 141 |
|
|
|
|
| 142 |
df_raw = pd.read_excel(filepath, sheet_name="Percent", header=None)
|
| 143 |
|
| 144 |
+
# Title row contains "Percentage Susceptibility from <Country>"
|
| 145 |
region = "Unknown"
|
| 146 |
+
for _, row in df_raw.head(5).iterrows():
|
| 147 |
cell = str(row.iloc[0]) if pd.notna(row.iloc[0]) else ""
|
| 148 |
if "from" in cell.lower():
|
|
|
|
| 149 |
parts = cell.split("from")
|
| 150 |
if len(parts) > 1:
|
| 151 |
region = parts[1].strip()
|
| 152 |
break
|
| 153 |
|
| 154 |
+
# Locate the actual header row by finding "Antibacterial"
|
| 155 |
+
header_row = 4
|
| 156 |
for idx, row in df_raw.head(10).iterrows():
|
| 157 |
if any('Antibacterial' in str(v) for v in row.values if pd.notna(v)):
|
| 158 |
header_row = idx
|
| 159 |
break
|
| 160 |
|
|
|
|
| 161 |
df = pd.read_excel(filepath, sheet_name="Percent", header=header_row)
|
|
|
|
|
|
|
| 162 |
df.columns = [str(col).strip().lower().replace(' ', '_').replace('.', '') for col in df.columns]
|
| 163 |
|
| 164 |
records = []
|
| 165 |
for _, row in df.iterrows():
|
| 166 |
antibiotic = str(row.get('antibacterial', ''))
|
|
|
|
|
|
|
| 167 |
if not antibiotic or antibiotic == 'nan' or 'omitted' in antibiotic.lower():
|
| 168 |
continue
|
| 169 |
if 'in vitro' in antibiotic.lower() or 'table cells' in antibiotic.lower():
|
| 170 |
continue
|
| 171 |
|
| 172 |
+
n_int = safe_int(row.get('n'))
|
| 173 |
+
s_float = safe_float(row.get('susc', row.get('susceptible')))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
|
| 175 |
if n_int is not None and s_float is not None:
|
| 176 |
records.append((
|
| 177 |
+
"General",
|
| 178 |
+
"",
|
| 179 |
antibiotic,
|
| 180 |
s_float,
|
| 181 |
+
safe_float(row.get('int', row.get('intermediate'))),
|
| 182 |
+
safe_float(row.get('res', row.get('resistant'))),
|
| 183 |
n_int,
|
| 184 |
+
2024,
|
| 185 |
region,
|
| 186 |
+
"ATLAS",
|
| 187 |
))
|
| 188 |
|
| 189 |
if records:
|
| 190 |
+
execute_many(
|
| 191 |
+
"""INSERT INTO atlas_susceptibility
|
| 192 |
+
(species, family, antibiotic, percent_susceptible,
|
| 193 |
+
percent_intermediate, percent_resistant, total_isolates,
|
| 194 |
+
year, region, source)
|
| 195 |
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
| 196 |
+
records,
|
| 197 |
+
)
|
| 198 |
print(f" Imported {len(records)} ATLAS susceptibility records from {region}")
|
| 199 |
|
| 200 |
return len(records)
|
| 201 |
|
| 202 |
|
| 203 |
def import_mic_breakpoints() -> int:
|
| 204 |
+
"""Import EUCAST MIC breakpoint tables from the Excel file."""
|
| 205 |
print("Importing MIC breakpoint data...")
|
| 206 |
|
| 207 |
filepath = DOCS_DIR / "mic_breakpoints" / "v_16.0__BreakpointTables.xlsx"
|
|
|
|
| 208 |
if not filepath.exists():
|
| 209 |
print(f" Warning: {filepath} not found, skipping...")
|
| 210 |
return 0
|
| 211 |
|
|
|
|
| 212 |
xl = pd.ExcelFile(filepath)
|
| 213 |
+
# These sheets contain metadata/guidance, not pathogen-specific breakpoints
|
|
|
|
| 214 |
skip_sheets = {'Content', 'Changes', 'Notes', 'Guidance', 'Dosages',
|
| 215 |
'Technical uncertainty', 'PK PD breakpoints', 'PK PD cutoffs'}
|
| 216 |
|
|
|
|
| 218 |
for sheet_name in xl.sheet_names:
|
| 219 |
if sheet_name in skip_sheets:
|
| 220 |
continue
|
|
|
|
| 221 |
try:
|
| 222 |
df = pd.read_excel(filepath, sheet_name=sheet_name, header=None)
|
| 223 |
+
for _, row in df.iterrows():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
row_values = [str(v).strip() for v in row.values if pd.notna(v)]
|
| 225 |
+
if len(row_values) < 2:
|
| 226 |
+
continue
|
| 227 |
+
|
| 228 |
+
potential_antibiotic = row_values[0]
|
| 229 |
+
if any(kw in potential_antibiotic.lower() for kw in
|
| 230 |
+
['antibiotic', 'agent', 'note', 'disk', 'mic', 'breakpoint']):
|
| 231 |
+
continue
|
| 232 |
|
| 233 |
+
# Extract numeric MIC values; strip inequality signs
|
| 234 |
+
mic_values = []
|
| 235 |
+
for v in row_values[1:]:
|
| 236 |
+
try:
|
| 237 |
+
mic_values.append(float(v.replace('≤', '').replace('>', '').replace('<', '').strip()))
|
| 238 |
+
except (ValueError, AttributeError):
|
| 239 |
+
pass
|
| 240 |
+
|
| 241 |
+
if len(mic_values) >= 2 and len(potential_antibiotic) > 2:
|
| 242 |
+
records.append((
|
| 243 |
+
sheet_name, # pathogen_group
|
| 244 |
+
potential_antibiotic,
|
| 245 |
+
None, # route
|
| 246 |
+
mic_values[0], # S breakpoint
|
| 247 |
+
mic_values[1], # R breakpoint
|
| 248 |
+
None, None, None, # disk S, disk R, notes
|
| 249 |
+
"16.0",
|
| 250 |
+
))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
except Exception as e:
|
| 252 |
print(f" Warning: Could not parse sheet '{sheet_name}': {e}")
|
| 253 |
continue
|
| 254 |
|
| 255 |
if records:
|
| 256 |
+
execute_many(
|
| 257 |
+
"""INSERT INTO mic_breakpoints
|
| 258 |
+
(pathogen_group, antibiotic, route, mic_susceptible, mic_resistant,
|
| 259 |
+
disk_susceptible, disk_resistant, notes, eucast_version)
|
| 260 |
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
| 261 |
+
records,
|
| 262 |
+
)
|
| 263 |
print(f" Imported {len(records)} MIC breakpoint records")
|
| 264 |
|
| 265 |
return len(records)
|
|
|
|
| 272 |
|
| 273 |
def _resolve_interactions_csv() -> Path | None:
|
| 274 |
"""
|
| 275 |
+
Find the drug interactions CSV file.
|
| 276 |
|
| 277 |
+
Checks in order:
|
| 278 |
+
1. docs/drug_safety/db_drug_interactions.csv (local)
|
| 279 |
+
2. /kaggle/input/drug-drug-interactions/ (Kaggle notebook with dataset attached)
|
| 280 |
+
3. Kaggle API download (requires ~/.kaggle/kaggle.json)
|
| 281 |
"""
|
|
|
|
| 282 |
if INTERACTIONS_CSV.exists():
|
| 283 |
return INTERACTIONS_CSV
|
| 284 |
|
| 285 |
+
if KAGGLE_INPUT_DIR.exists():
|
| 286 |
+
for candidate in KAGGLE_INPUT_DIR.glob("*.csv"):
|
| 287 |
+
print(f" Found CSV in Kaggle input: {candidate}")
|
| 288 |
+
return candidate
|
| 289 |
|
|
|
|
| 290 |
print(f" CSV not found — downloading from Kaggle dataset '{KAGGLE_DATASET}' ...")
|
| 291 |
try:
|
| 292 |
+
import kaggle # noqa: F401 — triggers credential check
|
| 293 |
+
import subprocess
|
| 294 |
dest = INTERACTIONS_CSV.parent
|
| 295 |
dest.mkdir(parents=True, exist_ok=True)
|
|
|
|
| 296 |
result = subprocess.run(
|
| 297 |
+
["kaggle", "datasets", "download", "-d", KAGGLE_DATASET, "--unzip", "-p", str(dest)],
|
|
|
|
| 298 |
capture_output=True, text=True,
|
| 299 |
)
|
| 300 |
if result.returncode == 0:
|
|
|
|
| 301 |
for f in dest.glob("*.csv"):
|
| 302 |
print(f" Downloaded: {f.name}")
|
| 303 |
return f
|
|
|
|
| 312 |
|
| 313 |
|
| 314 |
def import_drug_interactions(limit: int = None) -> int:
|
| 315 |
+
"""Import drug-drug interactions from the DDInter CSV (Kaggle dataset mghobashy/drug-drug-interactions)."""
|
| 316 |
print("Importing drug interactions data...")
|
| 317 |
|
| 318 |
filepath = _resolve_interactions_csv()
|
|
|
|
| 319 |
if filepath is None:
|
| 320 |
print(" Skipping drug interactions — CSV unavailable.")
|
| 321 |
print(f" To fix: attach the Kaggle dataset '{KAGGLE_DATASET}' to your notebook,")
|
| 322 |
print(" or set up ~/.kaggle/kaggle.json for API access.")
|
| 323 |
return 0
|
| 324 |
|
|
|
|
|
|
|
| 325 |
total_records = 0
|
| 326 |
+
for chunk in pd.read_csv(filepath, chunksize=10000):
|
|
|
|
|
|
|
| 327 |
chunk.columns = [col.strip().lower().replace(' ', '_') for col in chunk.columns]
|
| 328 |
|
| 329 |
records = []
|
|
|
|
| 332 |
drug_2 = str(row.get('drug_2', row.get('drug2', row.iloc[1] if len(row) > 1 else '')))
|
| 333 |
description = str(row.get('interaction_description', row.get('description',
|
| 334 |
row.get('interaction', row.iloc[2] if len(row) > 2 else ''))))
|
|
|
|
|
|
|
|
|
|
| 335 |
if drug_1 and drug_2:
|
| 336 |
+
records.append((drug_1, drug_2, description, classify_severity(description)))
|
| 337 |
|
| 338 |
if records:
|
| 339 |
+
execute_many(
|
| 340 |
+
"INSERT INTO drug_interactions (drug_1, drug_2, interaction_description, severity) VALUES (?, ?, ?, ?)",
|
| 341 |
+
records,
|
| 342 |
+
)
|
|
|
|
|
|
|
| 343 |
total_records += len(records)
|
| 344 |
|
| 345 |
if limit and total_records >= limit:
|
|
|
|
| 350 |
|
| 351 |
|
| 352 |
def import_all_data(interactions_limit: int = None) -> dict:
|
| 353 |
+
"""Initialize the database and import all structured data sources."""
|
| 354 |
print(f"\n{'='*50}")
|
| 355 |
print("Med-I-C Data Import")
|
| 356 |
print(f"{'='*50}\n")
|
| 357 |
|
|
|
|
| 358 |
init_database()
|
| 359 |
|
|
|
|
| 360 |
with get_connection() as conn:
|
| 361 |
+
for table in ("eml_antibiotics", "atlas_susceptibility", "mic_breakpoints", "drug_interactions"):
|
| 362 |
+
conn.execute(f"DELETE FROM {table}")
|
|
|
|
|
|
|
| 363 |
conn.commit()
|
| 364 |
print("Cleared existing data\n")
|
| 365 |
|
|
|
|
| 366 |
results = {
|
| 367 |
"eml_antibiotics": import_eml_antibiotics(),
|
| 368 |
"atlas_susceptibility": import_atlas_susceptibility(),
|
|
|
|
| 380 |
|
| 381 |
|
| 382 |
if __name__ == "__main__":
|
|
|
|
| 383 |
import_all_data(interactions_limit=50000)
|
src/graph.py
CHANGED
|
@@ -1,17 +1,13 @@
|
|
| 1 |
"""
|
| 2 |
-
LangGraph
|
| 3 |
|
| 4 |
-
|
|
|
|
| 5 |
|
| 6 |
-
Stage
|
| 7 |
-
Intake Historian
|
| 8 |
-
|
| 9 |
-
Stage 2 (Targeted - lab results available):
|
| 10 |
-
Intake Historian -> Vision Specialist -> Trend Analyst -> Clinical Pharmacologist
|
| 11 |
"""
|
| 12 |
|
| 13 |
-
from __future__ import annotations
|
| 14 |
-
|
| 15 |
import logging
|
| 16 |
from typing import Literal
|
| 17 |
|
|
@@ -28,189 +24,59 @@ from .state import InfectionState
|
|
| 28 |
logger = logging.getLogger(__name__)
|
| 29 |
|
| 30 |
|
| 31 |
-
# =============================================================================
|
| 32 |
-
# NODE FUNCTIONS (Wrapper for agents)
|
| 33 |
-
# =============================================================================
|
| 34 |
-
|
| 35 |
-
def intake_historian_node(state: InfectionState) -> InfectionState:
|
| 36 |
-
"""Node 1: Run Intake Historian agent."""
|
| 37 |
-
logger.info("Graph: Executing Intake Historian node")
|
| 38 |
-
return run_intake_historian(state)
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
def vision_specialist_node(state: InfectionState) -> InfectionState:
|
| 42 |
-
"""Node 2: Run Vision Specialist agent."""
|
| 43 |
-
logger.info("Graph: Executing Vision Specialist node")
|
| 44 |
-
return run_vision_specialist(state)
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
def trend_analyst_node(state: InfectionState) -> InfectionState:
|
| 48 |
-
"""Node 3: Run Trend Analyst agent."""
|
| 49 |
-
logger.info("Graph: Executing Trend Analyst node")
|
| 50 |
-
return run_trend_analyst(state)
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
def clinical_pharmacologist_node(state: InfectionState) -> InfectionState:
|
| 54 |
-
"""Node 4: Run Clinical Pharmacologist agent."""
|
| 55 |
-
logger.info("Graph: Executing Clinical Pharmacologist node")
|
| 56 |
-
return run_clinical_pharmacologist(state)
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
# =============================================================================
|
| 60 |
-
# CONDITIONAL ROUTING FUNCTIONS
|
| 61 |
-
# =============================================================================
|
| 62 |
-
|
| 63 |
def route_after_intake(state: InfectionState) -> Literal["vision_specialist", "clinical_pharmacologist"]:
|
| 64 |
-
"""
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
Routes to Vision Specialist if:
|
| 68 |
-
- stage is "targeted" AND
|
| 69 |
-
- route_to_vision is True (i.e., we have lab data to process)
|
| 70 |
-
|
| 71 |
-
Otherwise routes directly to Clinical Pharmacologist (empirical path).
|
| 72 |
-
"""
|
| 73 |
-
stage = state.get("stage", "empirical")
|
| 74 |
-
has_lab_data = state.get("route_to_vision", False)
|
| 75 |
-
|
| 76 |
-
if stage == "targeted" and has_lab_data:
|
| 77 |
-
logger.info("Graph: Routing to Vision Specialist (targeted path)")
|
| 78 |
return "vision_specialist"
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
return "clinical_pharmacologist"
|
| 82 |
|
| 83 |
|
| 84 |
def route_after_vision(state: InfectionState) -> Literal["trend_analyst", "clinical_pharmacologist"]:
|
| 85 |
-
"""
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
Routes to Trend Analyst if:
|
| 89 |
-
- route_to_trend_analyst is True (i.e., we have MIC data to analyze)
|
| 90 |
-
|
| 91 |
-
Otherwise skips to Clinical Pharmacologist.
|
| 92 |
-
"""
|
| 93 |
-
should_analyze_trends = state.get("route_to_trend_analyst", False)
|
| 94 |
-
|
| 95 |
-
if should_analyze_trends:
|
| 96 |
-
logger.info("Graph: Routing to Trend Analyst")
|
| 97 |
return "trend_analyst"
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
return "clinical_pharmacologist"
|
| 101 |
-
|
| 102 |
|
| 103 |
-
# =============================================================================
|
| 104 |
-
# GRAPH CONSTRUCTION
|
| 105 |
-
# =============================================================================
|
| 106 |
|
| 107 |
def build_infection_graph() -> StateGraph:
|
| 108 |
-
"""
|
| 109 |
-
Build the LangGraph StateGraph for the infection lifecycle workflow.
|
| 110 |
-
|
| 111 |
-
Returns:
|
| 112 |
-
Compiled StateGraph ready for execution
|
| 113 |
-
"""
|
| 114 |
-
# Create the graph with InfectionState as the state schema
|
| 115 |
graph = StateGraph(InfectionState)
|
| 116 |
|
| 117 |
-
|
| 118 |
-
graph.add_node("
|
| 119 |
-
graph.add_node("
|
| 120 |
-
graph.add_node("
|
| 121 |
-
graph.add_node("clinical_pharmacologist", clinical_pharmacologist_node)
|
| 122 |
|
| 123 |
-
# Set entry point
|
| 124 |
graph.set_entry_point("intake_historian")
|
| 125 |
|
| 126 |
-
# Add conditional edges from intake_historian
|
| 127 |
graph.add_conditional_edges(
|
| 128 |
"intake_historian",
|
| 129 |
route_after_intake,
|
| 130 |
-
{
|
| 131 |
-
"vision_specialist": "vision_specialist",
|
| 132 |
-
"clinical_pharmacologist": "clinical_pharmacologist",
|
| 133 |
-
}
|
| 134 |
)
|
| 135 |
-
|
| 136 |
-
# Add conditional edges from vision_specialist
|
| 137 |
graph.add_conditional_edges(
|
| 138 |
"vision_specialist",
|
| 139 |
route_after_vision,
|
| 140 |
-
{
|
| 141 |
-
"trend_analyst": "trend_analyst",
|
| 142 |
-
"clinical_pharmacologist": "clinical_pharmacologist",
|
| 143 |
-
}
|
| 144 |
)
|
| 145 |
|
| 146 |
-
# Add edge from trend_analyst to clinical_pharmacologist
|
| 147 |
graph.add_edge("trend_analyst", "clinical_pharmacologist")
|
| 148 |
-
|
| 149 |
-
# Add edge from clinical_pharmacologist to END
|
| 150 |
graph.add_edge("clinical_pharmacologist", END)
|
| 151 |
|
| 152 |
return graph
|
| 153 |
|
| 154 |
|
| 155 |
-
def
|
| 156 |
"""
|
| 157 |
-
|
| 158 |
|
| 159 |
-
|
| 160 |
-
|
| 161 |
"""
|
| 162 |
-
graph = build_infection_graph()
|
| 163 |
-
return graph.compile()
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
# =============================================================================
|
| 167 |
-
# EXECUTION HELPERS
|
| 168 |
-
# =============================================================================
|
| 169 |
-
|
| 170 |
-
def run_pipeline(
|
| 171 |
-
patient_data: dict,
|
| 172 |
-
labs_raw_text: str | None = None,
|
| 173 |
-
) -> InfectionState:
|
| 174 |
-
"""
|
| 175 |
-
Run the full infection lifecycle pipeline.
|
| 176 |
-
|
| 177 |
-
This is the main entry point for executing the multi-agent workflow.
|
| 178 |
-
|
| 179 |
-
Args:
|
| 180 |
-
patient_data: Dict containing patient information:
|
| 181 |
-
- age_years: Patient age
|
| 182 |
-
- weight_kg: Patient weight
|
| 183 |
-
- sex: "male" or "female"
|
| 184 |
-
- serum_creatinine_mg_dl: Serum creatinine (optional)
|
| 185 |
-
- medications: List of current medications
|
| 186 |
-
- allergies: List of allergies
|
| 187 |
-
- comorbidities: List of comorbidities
|
| 188 |
-
- infection_site: Site of infection
|
| 189 |
-
- suspected_source: Suspected pathogen/source
|
| 190 |
-
|
| 191 |
-
labs_raw_text: Raw text from lab report (if available).
|
| 192 |
-
If provided, triggers targeted (Stage 2) pathway.
|
| 193 |
-
|
| 194 |
-
Returns:
|
| 195 |
-
Final InfectionState with recommendation
|
| 196 |
-
|
| 197 |
-
Example:
|
| 198 |
-
>>> state = run_pipeline(
|
| 199 |
-
... patient_data={
|
| 200 |
-
... "age_years": 65,
|
| 201 |
-
... "weight_kg": 70,
|
| 202 |
-
... "sex": "male",
|
| 203 |
-
... "serum_creatinine_mg_dl": 1.2,
|
| 204 |
-
... "medications": ["metformin", "lisinopril"],
|
| 205 |
-
... "allergies": ["penicillin"],
|
| 206 |
-
... "infection_site": "urinary",
|
| 207 |
-
... "suspected_source": "community UTI",
|
| 208 |
-
... },
|
| 209 |
-
... labs_raw_text="E. coli isolated. Ciprofloxacin MIC: 0.5 mg/L (S)"
|
| 210 |
-
... )
|
| 211 |
-
>>> print(state["recommendation"]["primary_antibiotic"])
|
| 212 |
-
"""
|
| 213 |
-
# Build initial state from patient data
|
| 214 |
initial_state: InfectionState = {
|
| 215 |
"age_years": patient_data.get("age_years"),
|
| 216 |
"weight_kg": patient_data.get("weight_kg"),
|
|
@@ -224,59 +90,34 @@ def run_pipeline(
|
|
| 224 |
"suspected_source": patient_data.get("suspected_source"),
|
| 225 |
"country_or_region": patient_data.get("country_or_region"),
|
| 226 |
"vitals": patient_data.get("vitals", {}),
|
|
|
|
| 227 |
}
|
| 228 |
|
| 229 |
-
# Add lab data if provided
|
| 230 |
if labs_raw_text:
|
| 231 |
initial_state["labs_raw_text"] = labs_raw_text
|
| 232 |
-
initial_state["stage"] = "targeted"
|
| 233 |
-
else:
|
| 234 |
-
initial_state["stage"] = "empirical"
|
| 235 |
-
|
| 236 |
-
# Compile and run the graph
|
| 237 |
-
logger.info(f"Starting pipeline execution (stage: {initial_state['stage']})")
|
| 238 |
-
|
| 239 |
-
compiled_graph = compile_graph()
|
| 240 |
-
final_state = compiled_graph.invoke(initial_state)
|
| 241 |
-
|
| 242 |
-
logger.info("Pipeline execution complete")
|
| 243 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
return final_state
|
| 245 |
|
| 246 |
|
| 247 |
def run_empirical_pipeline(patient_data: dict) -> InfectionState:
|
| 248 |
-
"""
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
Shorthand for run_pipeline without lab data.
|
| 252 |
-
"""
|
| 253 |
-
return run_pipeline(patient_data, labs_raw_text=None)
|
| 254 |
|
| 255 |
|
| 256 |
def run_targeted_pipeline(patient_data: dict, labs_raw_text: str) -> InfectionState:
|
| 257 |
-
"""
|
| 258 |
-
Run Stage 2 (Targeted) pipeline with lab data.
|
| 259 |
-
|
| 260 |
-
Shorthand for run_pipeline with lab data.
|
| 261 |
-
"""
|
| 262 |
return run_pipeline(patient_data, labs_raw_text=labs_raw_text)
|
| 263 |
|
| 264 |
|
| 265 |
-
# =============================================================================
|
| 266 |
-
# VISUALIZATION (for debugging)
|
| 267 |
-
# =============================================================================
|
| 268 |
-
|
| 269 |
def get_graph_mermaid() -> str:
|
| 270 |
-
"""
|
| 271 |
-
Get Mermaid diagram representation of the graph.
|
| 272 |
-
|
| 273 |
-
Useful for documentation and debugging.
|
| 274 |
-
"""
|
| 275 |
-
graph = build_infection_graph()
|
| 276 |
try:
|
| 277 |
-
return
|
| 278 |
except Exception:
|
| 279 |
-
# Fallback: return manual diagram
|
| 280 |
return """
|
| 281 |
graph TD
|
| 282 |
A[intake_historian] --> B{route_after_intake}
|
|
@@ -288,13 +129,3 @@ graph TD
|
|
| 288 |
F --> E
|
| 289 |
E --> G[END]
|
| 290 |
"""
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
__all__ = [
|
| 294 |
-
"build_infection_graph",
|
| 295 |
-
"compile_graph",
|
| 296 |
-
"run_pipeline",
|
| 297 |
-
"run_empirical_pipeline",
|
| 298 |
-
"run_targeted_pipeline",
|
| 299 |
-
"get_graph_mermaid",
|
| 300 |
-
]
|
|
|
|
| 1 |
"""
|
| 2 |
+
LangGraph orchestrator for the infection lifecycle workflow.
|
| 3 |
|
| 4 |
+
Stage 1 (empirical - no lab results):
|
| 5 |
+
Intake Historian → Clinical Pharmacologist
|
| 6 |
|
| 7 |
+
Stage 2 (targeted - lab results available):
|
| 8 |
+
Intake Historian → Vision Specialist → [Trend Analyst →] Clinical Pharmacologist
|
|
|
|
|
|
|
|
|
|
| 9 |
"""
|
| 10 |
|
|
|
|
|
|
|
| 11 |
import logging
|
| 12 |
from typing import Literal
|
| 13 |
|
|
|
|
| 24 |
logger = logging.getLogger(__name__)
|
| 25 |
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
def route_after_intake(state: InfectionState) -> Literal["vision_specialist", "clinical_pharmacologist"]:
|
| 28 |
+
"""Route to Vision Specialist if we have lab text to parse; otherwise go straight to pharmacologist."""
|
| 29 |
+
if state.get("stage") == "targeted" and state.get("route_to_vision"):
|
| 30 |
+
logger.info("Graph: routing to Vision Specialist (targeted path)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
return "vision_specialist"
|
| 32 |
+
logger.info("Graph: routing to Clinical Pharmacologist (empirical path)")
|
| 33 |
+
return "clinical_pharmacologist"
|
|
|
|
| 34 |
|
| 35 |
|
| 36 |
def route_after_vision(state: InfectionState) -> Literal["trend_analyst", "clinical_pharmacologist"]:
|
| 37 |
+
"""Route to Trend Analyst if Vision Specialist extracted MIC values."""
|
| 38 |
+
if state.get("route_to_trend_analyst"):
|
| 39 |
+
logger.info("Graph: routing to Trend Analyst")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
return "trend_analyst"
|
| 41 |
+
logger.info("Graph: skipping Trend Analyst (no MIC data)")
|
| 42 |
+
return "clinical_pharmacologist"
|
|
|
|
|
|
|
| 43 |
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
def build_infection_graph() -> StateGraph:
|
| 46 |
+
"""Build and return the compiled LangGraph for the infection pipeline."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
graph = StateGraph(InfectionState)
|
| 48 |
|
| 49 |
+
graph.add_node("intake_historian", run_intake_historian)
|
| 50 |
+
graph.add_node("vision_specialist", run_vision_specialist)
|
| 51 |
+
graph.add_node("trend_analyst", run_trend_analyst)
|
| 52 |
+
graph.add_node("clinical_pharmacologist", run_clinical_pharmacologist)
|
|
|
|
| 53 |
|
|
|
|
| 54 |
graph.set_entry_point("intake_historian")
|
| 55 |
|
|
|
|
| 56 |
graph.add_conditional_edges(
|
| 57 |
"intake_historian",
|
| 58 |
route_after_intake,
|
| 59 |
+
{"vision_specialist": "vision_specialist", "clinical_pharmacologist": "clinical_pharmacologist"},
|
|
|
|
|
|
|
|
|
|
| 60 |
)
|
|
|
|
|
|
|
| 61 |
graph.add_conditional_edges(
|
| 62 |
"vision_specialist",
|
| 63 |
route_after_vision,
|
| 64 |
+
{"trend_analyst": "trend_analyst", "clinical_pharmacologist": "clinical_pharmacologist"},
|
|
|
|
|
|
|
|
|
|
| 65 |
)
|
| 66 |
|
|
|
|
| 67 |
graph.add_edge("trend_analyst", "clinical_pharmacologist")
|
|
|
|
|
|
|
| 68 |
graph.add_edge("clinical_pharmacologist", END)
|
| 69 |
|
| 70 |
return graph
|
| 71 |
|
| 72 |
|
| 73 |
+
def run_pipeline(patient_data: dict, labs_raw_text: str | None = None) -> InfectionState:
|
| 74 |
"""
|
| 75 |
+
Run the full infection pipeline and return the final state.
|
| 76 |
|
| 77 |
+
Pass labs_raw_text to trigger the targeted (Stage 2) pathway.
|
| 78 |
+
Without it, only the empirical (Stage 1) pathway runs.
|
| 79 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
initial_state: InfectionState = {
|
| 81 |
"age_years": patient_data.get("age_years"),
|
| 82 |
"weight_kg": patient_data.get("weight_kg"),
|
|
|
|
| 90 |
"suspected_source": patient_data.get("suspected_source"),
|
| 91 |
"country_or_region": patient_data.get("country_or_region"),
|
| 92 |
"vitals": patient_data.get("vitals", {}),
|
| 93 |
+
"stage": "targeted" if labs_raw_text else "empirical",
|
| 94 |
}
|
| 95 |
|
|
|
|
| 96 |
if labs_raw_text:
|
| 97 |
initial_state["labs_raw_text"] = labs_raw_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
+
logger.info(f"Starting pipeline (stage: {initial_state['stage']})")
|
| 100 |
+
compiled = build_infection_graph().compile()
|
| 101 |
+
final_state = compiled.invoke(initial_state)
|
| 102 |
+
logger.info("Pipeline complete")
|
| 103 |
return final_state
|
| 104 |
|
| 105 |
|
| 106 |
def run_empirical_pipeline(patient_data: dict) -> InfectionState:
    """Run the Stage 1 (empirical) pathway: ``run_pipeline`` with no lab data."""
    return run_pipeline(patient_data, labs_raw_text=None)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
|
| 111 |
def run_targeted_pipeline(patient_data: dict, labs_raw_text: str) -> InfectionState:
    """Run the Stage 2 (targeted) pathway: ``run_pipeline`` with raw lab text."""
    return run_pipeline(patient_data, labs_raw_text)
|
| 114 |
|
| 115 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
def get_graph_mermaid() -> str:
|
| 117 |
+
"""Return a Mermaid diagram of the graph (for documentation and debugging)."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
try:
|
| 119 |
+
return build_infection_graph().compile().get_graph().draw_mermaid()
|
| 120 |
except Exception:
|
|
|
|
| 121 |
return """
|
| 122 |
graph TD
|
| 123 |
A[intake_historian] --> B{route_after_intake}
|
|
|
|
| 129 |
F --> E
|
| 130 |
E --> G[END]
|
| 131 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/loader.py
CHANGED
|
@@ -1,22 +1,17 @@
|
|
| 1 |
|
| 2 |
-
from __future__ import annotations
|
| 3 |
-
|
| 4 |
import logging
|
| 5 |
from functools import lru_cache
|
| 6 |
-
from typing import Any, Callable, Dict, Literal, Optional
|
| 7 |
|
| 8 |
from .config import get_settings
|
| 9 |
|
| 10 |
-
|
| 11 |
logger = logging.getLogger(__name__)
|
| 12 |
|
| 13 |
TextBackend = Literal["vertex", "local"]
|
| 14 |
TextModelName = Literal["medgemma_4b", "medgemma_27b", "txgemma_9b", "txgemma_2b"]
|
| 15 |
|
| 16 |
|
| 17 |
-
def _resolve_backend(
|
| 18 |
-
requested: Optional[TextBackend],
|
| 19 |
-
) -> TextBackend:
|
| 20 |
settings = get_settings()
|
| 21 |
backend = requested or settings.default_backend # type: ignore[assignment]
|
| 22 |
if backend == "vertex" and not settings.use_vertex:
|
|
@@ -27,23 +22,16 @@ def _resolve_backend(
|
|
| 27 |
|
| 28 |
@lru_cache(maxsize=8)
|
| 29 |
def _get_vertex_chat_model(model_name: TextModelName):
|
| 30 |
-
"""
|
| 31 |
-
Lazily construct a Vertex AI chat model via langchain-google-vertexai.
|
| 32 |
-
|
| 33 |
-
Returns an object with an .invoke(str) method; we wrap this in a simple
|
| 34 |
-
callable for downstream use.
|
| 35 |
-
"""
|
| 36 |
-
|
| 37 |
try:
|
| 38 |
from langchain_google_vertexai import ChatVertexAI
|
| 39 |
-
except Exception as exc:
|
| 40 |
raise RuntimeError(
|
| 41 |
"langchain-google-vertexai is not available; "
|
| 42 |
"install it or switch MEDIC_DEFAULT_BACKEND=local."
|
| 43 |
) from exc
|
| 44 |
|
| 45 |
settings = get_settings()
|
| 46 |
-
|
| 47 |
if settings.vertex_project_id is None:
|
| 48 |
raise RuntimeError(
|
| 49 |
"MEDIC_VERTEX_PROJECT_ID is not set. "
|
|
@@ -56,40 +44,28 @@ def _get_vertex_chat_model(model_name: TextModelName):
|
|
| 56 |
"txgemma_9b": settings.vertex_txgemma_9b_model,
|
| 57 |
"txgemma_2b": settings.vertex_txgemma_2b_model,
|
| 58 |
}
|
| 59 |
-
model_id = model_id_map[model_name]
|
| 60 |
|
| 61 |
llm = ChatVertexAI(
|
| 62 |
-
model=
|
| 63 |
project=settings.vertex_project_id,
|
| 64 |
location=settings.vertex_location,
|
| 65 |
temperature=0.2,
|
| 66 |
)
|
| 67 |
|
| 68 |
def _call(prompt: str, **kwargs: Any) -> str:
|
| 69 |
-
"""Thin wrapper returning plain text from ChatVertexAI."""
|
| 70 |
-
|
| 71 |
result = llm.invoke(prompt, **kwargs)
|
| 72 |
-
|
| 73 |
-
content = getattr(result, "content", result)
|
| 74 |
-
return str(content)
|
| 75 |
|
| 76 |
return _call
|
| 77 |
|
| 78 |
|
| 79 |
@lru_cache(maxsize=8)
|
| 80 |
def _get_local_causal_lm(model_name: TextModelName):
|
| 81 |
-
"""
|
| 82 |
-
Lazily load a local transformers model for offline / Kaggle usage.
|
| 83 |
-
|
| 84 |
-
Assumes model paths are provided via MEDIC_LOCAL_* env vars and that
|
| 85 |
-
the appropriate model weights are available in the environment.
|
| 86 |
-
"""
|
| 87 |
-
|
| 88 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 89 |
import torch
|
| 90 |
|
| 91 |
settings = get_settings()
|
| 92 |
-
|
| 93 |
model_path_map: Dict[TextModelName, Optional[str]] = {
|
| 94 |
"medgemma_4b": settings.local_medgemma_4b_model,
|
| 95 |
"medgemma_27b": settings.local_medgemma_27b_model,
|
|
@@ -101,31 +77,19 @@ def _get_local_causal_lm(model_name: TextModelName):
|
|
| 101 |
if not model_path:
|
| 102 |
raise RuntimeError(
|
| 103 |
f"No local model path configured for {model_name}. "
|
| 104 |
-
|
| 105 |
)
|
| 106 |
|
| 107 |
-
load_kwargs: Dict[str, Any] = {
|
| 108 |
-
|
| 109 |
-
}
|
| 110 |
-
|
| 111 |
-
# Optional 4-bit quantization via bitsandbytes
|
| 112 |
-
if get_settings().quantization == "4bit":
|
| 113 |
load_kwargs["load_in_4bit"] = True
|
| 114 |
|
| 115 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 116 |
model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
|
| 117 |
|
| 118 |
-
def _call(
|
| 119 |
-
|
| 120 |
-
max_new_tokens: int = 512,
|
| 121 |
-
temperature: float = 0.2,
|
| 122 |
-
**generate_kwargs: Any,
|
| 123 |
-
) -> str:
|
| 124 |
-
inputs = tokenizer(prompt, return_tensors="pt")
|
| 125 |
-
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
| 126 |
-
|
| 127 |
do_sample = temperature > 0
|
| 128 |
-
|
| 129 |
with torch.no_grad():
|
| 130 |
output_ids = model.generate(
|
| 131 |
**inputs,
|
|
@@ -134,11 +98,9 @@ def _get_local_causal_lm(model_name: TextModelName):
|
|
| 134 |
max_new_tokens=max_new_tokens,
|
| 135 |
**generate_kwargs,
|
| 136 |
)
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
text = tokenizer.decode(generated_ids, skip_special_tokens=True)
|
| 141 |
-
return text.strip()
|
| 142 |
|
| 143 |
return _call
|
| 144 |
|
|
@@ -148,22 +110,9 @@ def get_text_model(
|
|
| 148 |
model_name: TextModelName = "medgemma_4b",
|
| 149 |
backend: Optional[TextBackend] = None,
|
| 150 |
) -> Callable[..., str]:
|
| 151 |
-
"""
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
Example:
|
| 155 |
-
|
| 156 |
-
from src.loader import get_text_model
|
| 157 |
-
model = get_text_model("medgemma_4b")
|
| 158 |
-
answer = model("Explain ESBL in simple terms.")
|
| 159 |
-
"""
|
| 160 |
-
|
| 161 |
-
resolved_backend = _resolve_backend(backend)
|
| 162 |
-
|
| 163 |
-
if resolved_backend == "vertex":
|
| 164 |
-
return _get_vertex_chat_model(model_name)
|
| 165 |
-
else:
|
| 166 |
-
return _get_local_causal_lm(model_name)
|
| 167 |
|
| 168 |
|
| 169 |
def run_inference(
|
|
@@ -174,28 +123,7 @@ def run_inference(
|
|
| 174 |
temperature: float = 0.2,
|
| 175 |
**kwargs: Any,
|
| 176 |
) -> str:
|
| 177 |
-
"""
|
| 178 |
-
Convenience wrapper around `get_text_model`.
|
| 179 |
-
|
| 180 |
-
This is the simplest entry point to use inside agents:
|
| 181 |
-
|
| 182 |
-
from src.loader import run_inference
|
| 183 |
-
text = run_inference(prompt, model_name="medgemma_4b")
|
| 184 |
-
"""
|
| 185 |
-
|
| 186 |
model = get_text_model(model_name=model_name, backend=backend)
|
| 187 |
-
return model(
|
| 188 |
-
prompt,
|
| 189 |
-
max_new_tokens=max_new_tokens,
|
| 190 |
-
temperature=temperature,
|
| 191 |
-
**kwargs,
|
| 192 |
-
)
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
__all__ = [
|
| 196 |
-
"TextBackend",
|
| 197 |
-
"TextModelName",
|
| 198 |
-
"get_text_model",
|
| 199 |
-
"run_inference",
|
| 200 |
-
]
|
| 201 |
|
|
|
|
| 1 |
|
|
|
|
|
|
|
| 2 |
import logging
|
| 3 |
from functools import lru_cache
|
| 4 |
+
from typing import Any, Callable, Dict, Literal, Optional
|
| 5 |
|
| 6 |
from .config import get_settings
|
| 7 |
|
|
|
|
| 8 |
logger = logging.getLogger(__name__)
|
| 9 |
|
| 10 |
TextBackend = Literal["vertex", "local"]
|
| 11 |
TextModelName = Literal["medgemma_4b", "medgemma_27b", "txgemma_9b", "txgemma_2b"]
|
| 12 |
|
| 13 |
|
| 14 |
+
def _resolve_backend(requested: Optional[TextBackend]) -> TextBackend:
|
|
|
|
|
|
|
| 15 |
settings = get_settings()
|
| 16 |
backend = requested or settings.default_backend # type: ignore[assignment]
|
| 17 |
if backend == "vertex" and not settings.use_vertex:
|
|
|
|
| 22 |
|
| 23 |
@lru_cache(maxsize=8)
|
| 24 |
def _get_vertex_chat_model(model_name: TextModelName):
|
| 25 |
+
"""Load a Vertex AI chat model and return a callable that takes a prompt string."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
try:
|
| 27 |
from langchain_google_vertexai import ChatVertexAI
|
| 28 |
+
except Exception as exc:
|
| 29 |
raise RuntimeError(
|
| 30 |
"langchain-google-vertexai is not available; "
|
| 31 |
"install it or switch MEDIC_DEFAULT_BACKEND=local."
|
| 32 |
) from exc
|
| 33 |
|
| 34 |
settings = get_settings()
|
|
|
|
| 35 |
if settings.vertex_project_id is None:
|
| 36 |
raise RuntimeError(
|
| 37 |
"MEDIC_VERTEX_PROJECT_ID is not set. "
|
|
|
|
| 44 |
"txgemma_9b": settings.vertex_txgemma_9b_model,
|
| 45 |
"txgemma_2b": settings.vertex_txgemma_2b_model,
|
| 46 |
}
|
|
|
|
| 47 |
|
| 48 |
llm = ChatVertexAI(
|
| 49 |
+
model=model_id_map[model_name],
|
| 50 |
project=settings.vertex_project_id,
|
| 51 |
location=settings.vertex_location,
|
| 52 |
temperature=0.2,
|
| 53 |
)
|
| 54 |
|
| 55 |
def _call(prompt: str, **kwargs: Any) -> str:
|
|
|
|
|
|
|
| 56 |
result = llm.invoke(prompt, **kwargs)
|
| 57 |
+
return str(getattr(result, "content", result))
|
|
|
|
|
|
|
| 58 |
|
| 59 |
return _call
|
| 60 |
|
| 61 |
|
| 62 |
@lru_cache(maxsize=8)
|
| 63 |
def _get_local_causal_lm(model_name: TextModelName):
|
| 64 |
+
"""Load a local HuggingFace causal LM and return a generation callable."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 66 |
import torch
|
| 67 |
|
| 68 |
settings = get_settings()
|
|
|
|
| 69 |
model_path_map: Dict[TextModelName, Optional[str]] = {
|
| 70 |
"medgemma_4b": settings.local_medgemma_4b_model,
|
| 71 |
"medgemma_27b": settings.local_medgemma_27b_model,
|
|
|
|
| 77 |
if not model_path:
|
| 78 |
raise RuntimeError(
|
| 79 |
f"No local model path configured for {model_name}. "
|
| 80 |
+
"Set MEDIC_LOCAL_*_MODEL or use the Vertex backend."
|
| 81 |
)
|
| 82 |
|
| 83 |
+
load_kwargs: Dict[str, Any] = {"device_map": "auto"}
|
| 84 |
+
if settings.quantization == "4bit":
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
load_kwargs["load_in_4bit"] = True
|
| 86 |
|
| 87 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 88 |
model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
|
| 89 |
|
| 90 |
+
def _call(prompt: str, max_new_tokens: int = 512, temperature: float = 0.2, **generate_kwargs: Any) -> str:
|
| 91 |
+
inputs = {k: v.to(model.device) for k, v in tokenizer(prompt, return_tensors="pt").items()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
do_sample = temperature > 0
|
|
|
|
| 93 |
with torch.no_grad():
|
| 94 |
output_ids = model.generate(
|
| 95 |
**inputs,
|
|
|
|
| 98 |
max_new_tokens=max_new_tokens,
|
| 99 |
**generate_kwargs,
|
| 100 |
)
|
| 101 |
+
# Decode only the newly generated tokens, not the input prompt
|
| 102 |
+
generated_ids = output_ids[0, inputs["input_ids"].shape[1]:]
|
| 103 |
+
return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
|
|
|
|
|
|
|
| 104 |
|
| 105 |
return _call
|
| 106 |
|
|
|
|
| 110 |
model_name: TextModelName = "medgemma_4b",
|
| 111 |
backend: Optional[TextBackend] = None,
|
| 112 |
) -> Callable[..., str]:
|
| 113 |
+
"""Return a cached callable for the requested model and backend."""
|
| 114 |
+
resolved = _resolve_backend(backend)
|
| 115 |
+
return _get_vertex_chat_model(model_name) if resolved == "vertex" else _get_local_causal_lm(model_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
|
| 117 |
|
| 118 |
def run_inference(
|
|
|
|
| 123 |
temperature: float = 0.2,
|
| 124 |
**kwargs: Any,
|
| 125 |
) -> str:
|
| 126 |
+
"""Run inference with the specified model. This is the primary entry point for agents."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
model = get_text_model(model_name=model_name, backend=backend)
|
| 128 |
+
return model(prompt, max_new_tokens=max_new_tokens, temperature=temperature, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
|
src/prompts.py
CHANGED
|
@@ -1,18 +1,7 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Prompt templates for Med-I-C multi-agent system.
|
| 3 |
-
|
| 4 |
-
Each agent has a specific role in the infection lifecycle workflow:
|
| 5 |
-
- Agent 1: Intake Historian - Parse patient data, risk factors, calculate CrCl
|
| 6 |
-
- Agent 2: Vision Specialist - Extract structured data from lab reports (images/PDFs)
|
| 7 |
-
- Agent 3: Trend Analyst - Detect MIC creep and resistance velocity
|
| 8 |
-
- Agent 4: Clinical Pharmacologist - Final Rx recommendations + safety checks
|
| 9 |
-
"""
|
| 10 |
|
| 11 |
-
from __future__ import annotations
|
| 12 |
|
| 13 |
-
#
|
| 14 |
-
# AGENT 1: INTAKE HISTORIAN
|
| 15 |
-
# =============================================================================
|
| 16 |
|
| 17 |
INTAKE_HISTORIAN_SYSTEM = """You are an expert clinical intake specialist. Your role is to:
|
| 18 |
|
|
@@ -66,9 +55,7 @@ RAG CONTEXT (Relevant Guidelines):
|
|
| 66 |
Provide your structured assessment following the system instructions."""
|
| 67 |
|
| 68 |
|
| 69 |
-
#
|
| 70 |
-
# AGENT 2: VISION SPECIALIST
|
| 71 |
-
# =============================================================================
|
| 72 |
|
| 73 |
VISION_SPECIALIST_SYSTEM = """You are an expert medical laboratory data extraction specialist. Your role is to:
|
| 74 |
|
|
@@ -131,9 +118,7 @@ Flag any critical findings that require urgent attention.
|
|
| 131 |
Provide your structured extraction following the system instructions."""
|
| 132 |
|
| 133 |
|
| 134 |
-
#
|
| 135 |
-
# AGENT 3: TREND ANALYST
|
| 136 |
-
# =============================================================================
|
| 137 |
|
| 138 |
TREND_ANALYST_SYSTEM = """You are an expert antimicrobial resistance trend analyst. Your role is to:
|
| 139 |
|
|
@@ -195,9 +180,7 @@ Analyze the trend, calculate risk level, and provide recommendations.
|
|
| 195 |
Follow the system instructions for output format."""
|
| 196 |
|
| 197 |
|
| 198 |
-
#
|
| 199 |
-
# AGENT 4: CLINICAL PHARMACOLOGIST
|
| 200 |
-
# =============================================================================
|
| 201 |
|
| 202 |
CLINICAL_PHARMACOLOGIST_SYSTEM = """You are an expert clinical pharmacologist specializing in infectious diseases and antimicrobial stewardship. Your role is to:
|
| 203 |
|
|
@@ -291,9 +274,7 @@ Provide your final recommendation following the system instructions.
|
|
| 291 |
Ensure all safety checks are performed and documented."""
|
| 292 |
|
| 293 |
|
| 294 |
-
#
|
| 295 |
-
# TXGEMMA SAFETY CHECKER (Supplementary)
|
| 296 |
-
# =============================================================================
|
| 297 |
|
| 298 |
TXGEMMA_SAFETY_PROMPT = """Evaluate the safety profile of the following antibiotic prescription:
|
| 299 |
|
|
@@ -315,9 +296,7 @@ Evaluate for:
|
|
| 315 |
Provide a brief safety assessment (2-3 sentences) and a risk rating (LOW/MODERATE/HIGH)."""
|
| 316 |
|
| 317 |
|
| 318 |
-
#
|
| 319 |
-
# HELPER TEMPLATES
|
| 320 |
-
# =============================================================================
|
| 321 |
|
| 322 |
ERROR_RECOVERY_PROMPT = """The previous agent encountered an error or produced invalid output.
|
| 323 |
|
|
@@ -338,18 +317,3 @@ CLINICAL SCENARIO:
|
|
| 338 |
- Local resistance patterns: {local_resistance}
|
| 339 |
|
| 340 |
Recommend appropriate empirical therapy following WHO AWaRe principles."""
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
__all__ = [
|
| 344 |
-
"INTAKE_HISTORIAN_SYSTEM",
|
| 345 |
-
"INTAKE_HISTORIAN_PROMPT",
|
| 346 |
-
"VISION_SPECIALIST_SYSTEM",
|
| 347 |
-
"VISION_SPECIALIST_PROMPT",
|
| 348 |
-
"TREND_ANALYST_SYSTEM",
|
| 349 |
-
"TREND_ANALYST_PROMPT",
|
| 350 |
-
"CLINICAL_PHARMACOLOGIST_SYSTEM",
|
| 351 |
-
"CLINICAL_PHARMACOLOGIST_PROMPT",
|
| 352 |
-
"TXGEMMA_SAFETY_PROMPT",
|
| 353 |
-
"ERROR_RECOVERY_PROMPT",
|
| 354 |
-
"FALLBACK_EMPIRICAL_PROMPT",
|
| 355 |
-
]
|
|
|
|
| 1 |
+
"""Prompt templates for each agent in the Med-I-C pipeline."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
|
|
|
| 3 |
|
| 4 |
+
# --- Agent 1: Intake Historian ---
|
|
|
|
|
|
|
| 5 |
|
| 6 |
INTAKE_HISTORIAN_SYSTEM = """You are an expert clinical intake specialist. Your role is to:
|
| 7 |
|
|
|
|
| 55 |
Provide your structured assessment following the system instructions."""
|
| 56 |
|
| 57 |
|
| 58 |
+
# --- Agent 2: Vision Specialist ---
|
|
|
|
|
|
|
| 59 |
|
| 60 |
VISION_SPECIALIST_SYSTEM = """You are an expert medical laboratory data extraction specialist. Your role is to:
|
| 61 |
|
|
|
|
| 118 |
Provide your structured extraction following the system instructions."""
|
| 119 |
|
| 120 |
|
| 121 |
+
# --- Agent 3: Trend Analyst ---
|
|
|
|
|
|
|
| 122 |
|
| 123 |
TREND_ANALYST_SYSTEM = """You are an expert antimicrobial resistance trend analyst. Your role is to:
|
| 124 |
|
|
|
|
| 180 |
Follow the system instructions for output format."""
|
| 181 |
|
| 182 |
|
| 183 |
+
# --- Agent 4: Clinical Pharmacologist ---
|
|
|
|
|
|
|
| 184 |
|
| 185 |
CLINICAL_PHARMACOLOGIST_SYSTEM = """You are an expert clinical pharmacologist specializing in infectious diseases and antimicrobial stewardship. Your role is to:
|
| 186 |
|
|
|
|
| 274 |
Ensure all safety checks are performed and documented."""
|
| 275 |
|
| 276 |
|
| 277 |
+
# --- TxGemma safety check (supplementary, not primary decision-making) ---
|
|
|
|
|
|
|
| 278 |
|
| 279 |
TXGEMMA_SAFETY_PROMPT = """Evaluate the safety profile of the following antibiotic prescription:
|
| 280 |
|
|
|
|
| 296 |
Provide a brief safety assessment (2-3 sentences) and a risk rating (LOW/MODERATE/HIGH)."""
|
| 297 |
|
| 298 |
|
| 299 |
+
# --- Fallback templates ---
|
|
|
|
|
|
|
| 300 |
|
| 301 |
ERROR_RECOVERY_PROMPT = """The previous agent encountered an error or produced invalid output.
|
| 302 |
|
|
|
|
| 317 |
- Local resistance patterns: {local_resistance}
|
| 318 |
|
| 319 |
Recommend appropriate empirical therapy following WHO AWaRe principles."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/rag.py
CHANGED
|
@@ -1,116 +1,80 @@
|
|
| 1 |
"""
|
| 2 |
-
RAG
|
| 3 |
|
| 4 |
-
|
| 5 |
-
-
|
| 6 |
-
-
|
| 7 |
-
- drug_safety: Drug interactions
|
| 8 |
-
- pathogen_resistance:
|
| 9 |
"""
|
| 10 |
|
| 11 |
-
from __future__ import annotations
|
| 12 |
-
|
| 13 |
import logging
|
| 14 |
-
from pathlib import Path
|
| 15 |
from typing import Any, Dict, List, Optional
|
| 16 |
|
| 17 |
from .config import get_settings
|
| 18 |
|
| 19 |
logger = logging.getLogger(__name__)
|
| 20 |
|
| 21 |
-
|
| 22 |
-
# =============================================================================
|
| 23 |
-
# CHROMA CLIENT & EMBEDDING SETUP
|
| 24 |
-
# =============================================================================
|
| 25 |
-
|
| 26 |
_chroma_client = None
|
| 27 |
_embedding_function = None
|
| 28 |
|
| 29 |
|
| 30 |
def get_chroma_client():
|
| 31 |
-
"""
|
| 32 |
global _chroma_client
|
| 33 |
if _chroma_client is None:
|
| 34 |
import chromadb
|
| 35 |
-
|
| 36 |
-
settings = get_settings()
|
| 37 |
-
chroma_path = settings.chroma_db_dir
|
| 38 |
chroma_path.mkdir(parents=True, exist_ok=True)
|
| 39 |
_chroma_client = chromadb.PersistentClient(path=str(chroma_path))
|
| 40 |
return _chroma_client
|
| 41 |
|
| 42 |
|
| 43 |
def get_embedding_function():
|
| 44 |
-
"""
|
| 45 |
global _embedding_function
|
| 46 |
if _embedding_function is None:
|
| 47 |
from chromadb.utils import embedding_functions
|
| 48 |
-
|
| 49 |
-
|
| 50 |
_embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
|
| 51 |
-
model_name=
|
| 52 |
)
|
| 53 |
return _embedding_function
|
| 54 |
|
| 55 |
|
| 56 |
def get_collection(name: str):
|
| 57 |
-
"""
|
| 58 |
-
Get a ChromaDB collection by name.
|
| 59 |
-
|
| 60 |
-
Returns None if collection doesn't exist.
|
| 61 |
-
"""
|
| 62 |
-
client = get_chroma_client()
|
| 63 |
-
ef = get_embedding_function()
|
| 64 |
-
|
| 65 |
try:
|
| 66 |
-
return
|
| 67 |
except Exception:
|
| 68 |
logger.warning(f"Collection '{name}' not found")
|
| 69 |
return None
|
| 70 |
|
| 71 |
|
| 72 |
-
# =============================================================================
|
| 73 |
-
# COLLECTION-SPECIFIC RETRIEVERS
|
| 74 |
-
# =============================================================================
|
| 75 |
-
|
| 76 |
def search_antibiotic_guidelines(
|
| 77 |
query: str,
|
| 78 |
n_results: int = 5,
|
| 79 |
pathogen_filter: Optional[str] = None,
|
| 80 |
) -> List[Dict[str, Any]]:
|
| 81 |
-
"""
|
| 82 |
-
Search antibiotic treatment guidelines.
|
| 83 |
-
|
| 84 |
-
Args:
|
| 85 |
-
query: Search query
|
| 86 |
-
n_results: Number of results to return
|
| 87 |
-
pathogen_filter: Optional pathogen type filter (e.g., "ESBL-E", "CRE")
|
| 88 |
-
|
| 89 |
-
Returns:
|
| 90 |
-
List of relevant guideline excerpts with metadata
|
| 91 |
-
"""
|
| 92 |
collection = get_collection("idsa_treatment_guidelines")
|
| 93 |
if collection is None:
|
| 94 |
-
logger.warning("idsa_treatment_guidelines collection not available")
|
| 95 |
return []
|
| 96 |
-
|
| 97 |
-
where_filter = None
|
| 98 |
-
if pathogen_filter:
|
| 99 |
-
where_filter = {"pathogen_type": pathogen_filter}
|
| 100 |
-
|
| 101 |
try:
|
|
|
|
| 102 |
results = collection.query(
|
| 103 |
query_texts=[query],
|
| 104 |
n_results=n_results,
|
| 105 |
-
where=
|
| 106 |
include=["documents", "metadatas", "distances"],
|
| 107 |
)
|
|
|
|
| 108 |
except Exception as e:
|
| 109 |
logger.error(f"Error querying guidelines: {e}")
|
| 110 |
return []
|
| 111 |
|
| 112 |
-
return _format_results(results)
|
| 113 |
-
|
| 114 |
|
| 115 |
def search_mic_breakpoints(
|
| 116 |
query: str,
|
|
@@ -118,79 +82,45 @@ def search_mic_breakpoints(
|
|
| 118 |
organism: Optional[str] = None,
|
| 119 |
antibiotic: Optional[str] = None,
|
| 120 |
) -> List[Dict[str, Any]]:
|
| 121 |
-
"""
|
| 122 |
-
Search MIC breakpoint reference documentation.
|
| 123 |
-
|
| 124 |
-
Args:
|
| 125 |
-
query: Search query
|
| 126 |
-
n_results: Number of results
|
| 127 |
-
organism: Optional organism name filter
|
| 128 |
-
antibiotic: Optional antibiotic name filter
|
| 129 |
-
|
| 130 |
-
Returns:
|
| 131 |
-
List of relevant breakpoint information
|
| 132 |
-
"""
|
| 133 |
collection = get_collection("mic_reference_docs")
|
| 134 |
if collection is None:
|
| 135 |
-
logger.warning("mic_reference_docs collection not available")
|
| 136 |
return []
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
enhanced_query = query
|
| 140 |
-
if organism:
|
| 141 |
-
enhanced_query = f"{organism} {enhanced_query}"
|
| 142 |
-
if antibiotic:
|
| 143 |
-
enhanced_query = f"{antibiotic} {enhanced_query}"
|
| 144 |
-
|
| 145 |
try:
|
| 146 |
results = collection.query(
|
| 147 |
query_texts=[enhanced_query],
|
| 148 |
n_results=n_results,
|
| 149 |
include=["documents", "metadatas", "distances"],
|
| 150 |
)
|
|
|
|
| 151 |
except Exception as e:
|
| 152 |
logger.error(f"Error querying breakpoints: {e}")
|
| 153 |
return []
|
| 154 |
|
| 155 |
-
return _format_results(results)
|
| 156 |
-
|
| 157 |
|
| 158 |
def search_drug_safety(
|
| 159 |
query: str,
|
| 160 |
n_results: int = 5,
|
| 161 |
drug_name: Optional[str] = None,
|
| 162 |
) -> List[Dict[str, Any]]:
|
| 163 |
-
"""
|
| 164 |
-
Search drug safety information (interactions, warnings, contraindications).
|
| 165 |
-
|
| 166 |
-
Args:
|
| 167 |
-
query: Search query
|
| 168 |
-
n_results: Number of results
|
| 169 |
-
drug_name: Optional drug name to focus search
|
| 170 |
-
|
| 171 |
-
Returns:
|
| 172 |
-
List of relevant safety information
|
| 173 |
-
"""
|
| 174 |
collection = get_collection("drug_safety")
|
| 175 |
if collection is None:
|
| 176 |
-
# Fallback: try existing collections
|
| 177 |
-
logger.warning("drug_safety collection not available")
|
| 178 |
return []
|
| 179 |
-
|
| 180 |
enhanced_query = f"{drug_name} {query}" if drug_name else query
|
| 181 |
-
|
| 182 |
try:
|
| 183 |
results = collection.query(
|
| 184 |
query_texts=[enhanced_query],
|
| 185 |
n_results=n_results,
|
| 186 |
include=["documents", "metadatas", "distances"],
|
| 187 |
)
|
|
|
|
| 188 |
except Exception as e:
|
| 189 |
logger.error(f"Error querying drug safety: {e}")
|
| 190 |
return []
|
| 191 |
|
| 192 |
-
return _format_results(results)
|
| 193 |
-
|
| 194 |
|
| 195 |
def search_resistance_patterns(
|
| 196 |
query: str,
|
|
@@ -198,45 +128,22 @@ def search_resistance_patterns(
|
|
| 198 |
organism: Optional[str] = None,
|
| 199 |
region: Optional[str] = None,
|
| 200 |
) -> List[Dict[str, Any]]:
|
| 201 |
-
"""
|
| 202 |
-
Search pathogen resistance pattern data.
|
| 203 |
-
|
| 204 |
-
Args:
|
| 205 |
-
query: Search query
|
| 206 |
-
n_results: Number of results
|
| 207 |
-
organism: Optional organism filter
|
| 208 |
-
region: Optional geographic region filter
|
| 209 |
-
|
| 210 |
-
Returns:
|
| 211 |
-
List of relevant resistance data
|
| 212 |
-
"""
|
| 213 |
collection = get_collection("pathogen_resistance")
|
| 214 |
if collection is None:
|
| 215 |
-
logger.warning("pathogen_resistance collection not available")
|
| 216 |
return []
|
| 217 |
-
|
| 218 |
-
enhanced_query = query
|
| 219 |
-
if organism:
|
| 220 |
-
enhanced_query = f"{organism} {enhanced_query}"
|
| 221 |
-
if region:
|
| 222 |
-
enhanced_query = f"{region} {enhanced_query}"
|
| 223 |
-
|
| 224 |
try:
|
| 225 |
results = collection.query(
|
| 226 |
query_texts=[enhanced_query],
|
| 227 |
n_results=n_results,
|
| 228 |
include=["documents", "metadatas", "distances"],
|
| 229 |
)
|
|
|
|
| 230 |
except Exception as e:
|
| 231 |
logger.error(f"Error querying resistance patterns: {e}")
|
| 232 |
return []
|
| 233 |
|
| 234 |
-
return _format_results(results)
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
# =============================================================================
|
| 238 |
-
# UNIFIED CONTEXT RETRIEVER
|
| 239 |
-
# =============================================================================
|
| 240 |
|
| 241 |
def get_context_for_agent(
|
| 242 |
agent_name: str,
|
|
@@ -245,238 +152,104 @@ def get_context_for_agent(
|
|
| 245 |
n_results: int = 3,
|
| 246 |
) -> str:
|
| 247 |
"""
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
This is the main entry point for agents to retrieve context.
|
| 251 |
-
|
| 252 |
-
Args:
|
| 253 |
-
agent_name: Name of the requesting agent
|
| 254 |
-
query: The primary search query
|
| 255 |
-
patient_context: Optional dict with patient-specific info
|
| 256 |
-
n_results: Number of results per collection
|
| 257 |
|
| 258 |
-
|
| 259 |
-
|
|
|
|
|
|
|
|
|
|
| 260 |
"""
|
| 261 |
-
|
| 262 |
-
|
| 263 |
|
| 264 |
if agent_name == "intake_historian":
|
| 265 |
-
|
| 266 |
-
guidelines = search_antibiotic_guidelines(
|
| 267 |
-
query=query,
|
| 268 |
-
n_results=n_results,
|
| 269 |
-
pathogen_filter=patient_context.get("pathogen_type"),
|
| 270 |
-
)
|
| 271 |
if guidelines:
|
| 272 |
-
|
| 273 |
for g in guidelines:
|
| 274 |
-
|
| 275 |
-
|
| 276 |
|
| 277 |
elif agent_name == "vision_specialist":
|
| 278 |
-
|
| 279 |
-
breakpoints = search_mic_breakpoints(
|
| 280 |
-
query=query,
|
| 281 |
-
n_results=n_results,
|
| 282 |
-
organism=patient_context.get("organism"),
|
| 283 |
-
antibiotic=patient_context.get("antibiotic"),
|
| 284 |
-
)
|
| 285 |
if breakpoints:
|
| 286 |
-
|
| 287 |
for b in breakpoints:
|
| 288 |
-
|
| 289 |
|
| 290 |
elif agent_name == "trend_analyst":
|
| 291 |
-
# Get breakpoints and resistance trends
|
| 292 |
breakpoints = search_mic_breakpoints(
|
| 293 |
-
|
| 294 |
-
n_results=n_results,
|
| 295 |
-
)
|
| 296 |
-
resistance = search_resistance_patterns(
|
| 297 |
-
query=query,
|
| 298 |
n_results=n_results,
|
| 299 |
-
organism=patient_context.get("organism"),
|
| 300 |
-
region=patient_context.get("region"),
|
| 301 |
)
|
| 302 |
-
|
| 303 |
if breakpoints:
|
| 304 |
-
|
| 305 |
for b in breakpoints:
|
| 306 |
-
|
| 307 |
-
|
| 308 |
if resistance:
|
| 309 |
-
|
| 310 |
for r in resistance:
|
| 311 |
-
|
| 312 |
|
| 313 |
elif agent_name == "clinical_pharmacologist":
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
query=query,
|
| 317 |
-
n_results=n_results,
|
| 318 |
-
)
|
| 319 |
-
safety = search_drug_safety(
|
| 320 |
-
query=query,
|
| 321 |
-
n_results=n_results,
|
| 322 |
-
drug_name=patient_context.get("proposed_antibiotic"),
|
| 323 |
-
)
|
| 324 |
-
|
| 325 |
if guidelines:
|
| 326 |
-
|
| 327 |
for g in guidelines:
|
| 328 |
-
|
| 329 |
-
|
| 330 |
if safety:
|
| 331 |
-
|
| 332 |
for s in safety:
|
| 333 |
-
|
| 334 |
|
| 335 |
else:
|
| 336 |
-
# Generic retrieval
|
| 337 |
guidelines = search_antibiotic_guidelines(query, n_results=n_results)
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
context_parts.append(f"- {g['content'][:500]}...")
|
| 341 |
-
|
| 342 |
-
if not context_parts:
|
| 343 |
-
return "No relevant context found in knowledge base."
|
| 344 |
-
|
| 345 |
-
return "\n".join(context_parts)
|
| 346 |
-
|
| 347 |
|
| 348 |
-
|
| 349 |
-
query: str,
|
| 350 |
-
collections: Optional[List[str]] = None,
|
| 351 |
-
n_results_per_collection: int = 3,
|
| 352 |
-
**filters,
|
| 353 |
-
) -> str:
|
| 354 |
-
"""
|
| 355 |
-
Get a combined context string from multiple collections.
|
| 356 |
|
| 357 |
-
This is a simpler interface for general-purpose RAG retrieval.
|
| 358 |
-
|
| 359 |
-
Args:
|
| 360 |
-
query: Search query
|
| 361 |
-
collections: List of collection names to search (defaults to all)
|
| 362 |
-
n_results_per_collection: Results per collection
|
| 363 |
-
**filters: Additional filters (organism, antibiotic, region, etc.)
|
| 364 |
-
|
| 365 |
-
Returns:
|
| 366 |
-
Combined context string
|
| 367 |
-
"""
|
| 368 |
-
default_collections = [
|
| 369 |
-
"idsa_treatment_guidelines",
|
| 370 |
-
"mic_reference_docs",
|
| 371 |
-
]
|
| 372 |
-
collections = collections or default_collections
|
| 373 |
-
|
| 374 |
-
context_parts = []
|
| 375 |
-
|
| 376 |
-
for collection_name in collections:
|
| 377 |
-
if collection_name == "idsa_treatment_guidelines":
|
| 378 |
-
results = search_antibiotic_guidelines(
|
| 379 |
-
query,
|
| 380 |
-
n_results=n_results_per_collection,
|
| 381 |
-
pathogen_filter=filters.get("pathogen_type"),
|
| 382 |
-
)
|
| 383 |
-
elif collection_name == "mic_reference_docs":
|
| 384 |
-
results = search_mic_breakpoints(
|
| 385 |
-
query,
|
| 386 |
-
n_results=n_results_per_collection,
|
| 387 |
-
organism=filters.get("organism"),
|
| 388 |
-
antibiotic=filters.get("antibiotic"),
|
| 389 |
-
)
|
| 390 |
-
elif collection_name == "drug_safety":
|
| 391 |
-
results = search_drug_safety(
|
| 392 |
-
query,
|
| 393 |
-
n_results=n_results_per_collection,
|
| 394 |
-
drug_name=filters.get("drug_name"),
|
| 395 |
-
)
|
| 396 |
-
elif collection_name == "pathogen_resistance":
|
| 397 |
-
results = search_resistance_patterns(
|
| 398 |
-
query,
|
| 399 |
-
n_results=n_results_per_collection,
|
| 400 |
-
organism=filters.get("organism"),
|
| 401 |
-
region=filters.get("region"),
|
| 402 |
-
)
|
| 403 |
-
else:
|
| 404 |
-
continue
|
| 405 |
-
|
| 406 |
-
if results:
|
| 407 |
-
context_parts.append(f"=== {collection_name.upper()} ===")
|
| 408 |
-
for r in results:
|
| 409 |
-
context_parts.append(r["content"])
|
| 410 |
-
context_parts.append(f"[Relevance: {1 - r.get('distance', 0):.2f}]")
|
| 411 |
-
context_parts.append("")
|
| 412 |
-
|
| 413 |
-
return "\n".join(context_parts) if context_parts else "No relevant context found."
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
# =============================================================================
|
| 417 |
-
# HELPER FUNCTIONS
|
| 418 |
-
# =============================================================================
|
| 419 |
|
| 420 |
def _format_results(results: Dict[str, Any]) -> List[Dict[str, Any]]:
|
| 421 |
-
"""
|
| 422 |
if not results or not results.get("documents"):
|
| 423 |
return []
|
| 424 |
|
| 425 |
-
formatted = []
|
| 426 |
documents = results["documents"][0] if results["documents"] else []
|
| 427 |
metadatas = results.get("metadatas", [[]])[0]
|
| 428 |
distances = results.get("distances", [[]])[0]
|
| 429 |
|
| 430 |
-
|
| 431 |
-
|
| 432 |
"content": doc,
|
| 433 |
"metadata": metadatas[i] if i < len(metadatas) else {},
|
| 434 |
"distance": distances[i] if i < len(distances) else None,
|
| 435 |
"source": metadatas[i].get("source", "Unknown") if i < len(metadatas) else "Unknown",
|
| 436 |
"relevance_score": 1 - (distances[i] if i < len(distances) else 0),
|
| 437 |
-
}
|
| 438 |
-
|
| 439 |
-
|
| 440 |
|
| 441 |
|
| 442 |
def list_available_collections() -> List[str]:
|
| 443 |
-
"""
|
| 444 |
-
client = get_chroma_client()
|
| 445 |
try:
|
| 446 |
-
|
| 447 |
-
return [c.name for c in collections]
|
| 448 |
except Exception as e:
|
| 449 |
logger.error(f"Error listing collections: {e}")
|
| 450 |
return []
|
| 451 |
|
| 452 |
|
| 453 |
def get_collection_info(name: str) -> Optional[Dict[str, Any]]:
|
| 454 |
-
"""
|
| 455 |
collection = get_collection(name)
|
| 456 |
if collection is None:
|
| 457 |
return None
|
| 458 |
-
|
| 459 |
try:
|
| 460 |
-
return {
|
| 461 |
-
"name": collection.name,
|
| 462 |
-
"count": collection.count(),
|
| 463 |
-
"metadata": collection.metadata,
|
| 464 |
-
}
|
| 465 |
except Exception as e:
|
| 466 |
logger.error(f"Error getting collection info: {e}")
|
| 467 |
return None
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
__all__ = [
|
| 471 |
-
"get_chroma_client",
|
| 472 |
-
"get_embedding_function",
|
| 473 |
-
"get_collection",
|
| 474 |
-
"search_antibiotic_guidelines",
|
| 475 |
-
"search_mic_breakpoints",
|
| 476 |
-
"search_drug_safety",
|
| 477 |
-
"search_resistance_patterns",
|
| 478 |
-
"get_context_for_agent",
|
| 479 |
-
"get_context_string",
|
| 480 |
-
"list_available_collections",
|
| 481 |
-
"get_collection_info",
|
| 482 |
-
]
|
|
|
|
| 1 |
"""
|
| 2 |
+
RAG module for Med-I-C.
|
| 3 |
|
| 4 |
+
Retrieves context from four ChromaDB collections:
|
| 5 |
+
- idsa_treatment_guidelines: IDSA 2024 AMR guidance
|
| 6 |
+
- mic_reference_docs: EUCAST v16.0 breakpoint tables
|
| 7 |
+
- drug_safety: Drug interactions and contraindications
|
| 8 |
+
- pathogen_resistance: ATLAS regional susceptibility data
|
| 9 |
"""
|
| 10 |
|
|
|
|
|
|
|
| 11 |
import logging
|
|
|
|
| 12 |
from typing import Any, Dict, List, Optional
|
| 13 |
|
| 14 |
from .config import get_settings
|
| 15 |
|
| 16 |
logger = logging.getLogger(__name__)
|
| 17 |
|
| 18 |
+
# Module-level singletons; initialized lazily to avoid import-time side effects
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
_chroma_client = None
|
| 20 |
_embedding_function = None
|
| 21 |
|
| 22 |
|
| 23 |
def get_chroma_client():
    """Return the ChromaDB persistent client, creating it on first call."""
    global _chroma_client
    if _chroma_client is not None:
        return _chroma_client

    import chromadb

    db_path = get_settings().chroma_db_dir
    db_path.mkdir(parents=True, exist_ok=True)
    _chroma_client = chromadb.PersistentClient(path=str(db_path))
    return _chroma_client
|
| 32 |
|
| 33 |
|
| 34 |
def get_embedding_function():
    """Return the SentenceTransformer embedding function, creating it on first call."""
    global _embedding_function
    if _embedding_function is not None:
        return _embedding_function

    from chromadb.utils import embedding_functions

    # Use only the model short name (not the full HuggingFace path)
    short_name = get_settings().embedding_model_name.split("/")[-1]
    _embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=short_name
    )
    return _embedding_function
|
| 45 |
|
| 46 |
|
| 47 |
def get_collection(name: str):
    """Return a ChromaDB collection by name, or None if it does not exist."""
    client = get_chroma_client()
    try:
        return client.get_collection(
            name=name, embedding_function=get_embedding_function()
        )
    except Exception:
        # Chroma raises when the collection is missing; treat as "not found"
        logger.warning(f"Collection '{name}' not found")
        return None
|
| 54 |
|
| 55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
def search_antibiotic_guidelines(
    query: str,
    n_results: int = 5,
    pathogen_filter: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """Search the IDSA treatment guidelines collection."""
    collection = get_collection("idsa_treatment_guidelines")
    if collection is None:
        return []

    # Optional metadata filter narrows results to one pathogen type
    metadata_filter = {"pathogen_type": pathogen_filter} if pathogen_filter else None
    try:
        raw = collection.query(
            query_texts=[query],
            n_results=n_results,
            where=metadata_filter,
            include=["documents", "metadatas", "distances"],
        )
        return _format_results(raw)
    except Exception as e:
        logger.error(f"Error querying guidelines: {e}")
        return []
|
| 77 |
|
|
|
|
|
|
|
| 78 |
|
| 79 |
def search_mic_breakpoints(
    query: str,
    n_results: int = 5,
    organism: Optional[str] = None,
    antibiotic: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """Search the EUCAST MIC breakpoint reference collection."""
    collection = get_collection("mic_reference_docs")
    if collection is None:
        return []

    # Prepend organism/antibiotic to query to narrow semantic search
    terms = [t for t in (organism, antibiotic, query) if t]
    try:
        raw = collection.query(
            query_texts=[" ".join(terms)],
            n_results=n_results,
            include=["documents", "metadatas", "distances"],
        )
        return _format_results(raw)
    except Exception as e:
        logger.error(f"Error querying breakpoints: {e}")
        return []
|
| 101 |
|
|
|
|
|
|
|
| 102 |
|
| 103 |
def search_drug_safety(
    query: str,
    n_results: int = 5,
    drug_name: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """Search the drug safety collection (interactions, warnings, contraindications)."""
    collection = get_collection("drug_safety")
    if collection is None:
        return []

    # Prefix the drug name (when given) so results stay drug-specific
    search_text = f"{drug_name} {query}" if drug_name else query
    try:
        raw = collection.query(
            query_texts=[search_text],
            n_results=n_results,
            include=["documents", "metadatas", "distances"],
        )
        return _format_results(raw)
    except Exception as e:
        logger.error(f"Error querying drug safety: {e}")
        return []
|
| 123 |
|
|
|
|
|
|
|
| 124 |
|
| 125 |
def search_resistance_patterns(
    query: str,
    n_results: int = 5,
    organism: Optional[str] = None,
    region: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """Search the ATLAS pathogen resistance collection."""
    collection = get_collection("pathogen_resistance")
    if collection is None:
        return []

    # Prepend region/organism to query to narrow semantic search
    terms = [t for t in (region, organism, query) if t]
    try:
        raw = collection.query(
            query_texts=[" ".join(terms)],
            n_results=n_results,
            include=["documents", "metadatas", "distances"],
        )
        return _format_results(raw)
    except Exception as e:
        logger.error(f"Error querying resistance patterns: {e}")
        return []
|
| 146 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
def get_context_for_agent(
    agent_name: str,
    query: str,
    patient_context: Optional[Dict[str, Any]] = None,
    n_results: int = 3,
) -> str:
    """
    Return a formatted context string for a specific agent.

    Each agent draws from the collections most relevant to its task:
    - intake_historian: IDSA guidelines
    - vision_specialist: MIC breakpoints
    - trend_analyst: MIC breakpoints + resistance patterns
    - clinical_pharmacologist: guidelines + drug safety
    """
    info = patient_context or {}
    sections: List[str] = []

    if agent_name == "intake_historian":
        hits = search_antibiotic_guidelines(
            query, n_results=n_results, pathogen_filter=info.get("pathogen_type")
        )
        if hits:
            sections.append("RELEVANT TREATMENT GUIDELINES:")
            for hit in hits:
                sections.append(f"- {hit['content'][:500]}...")
                sections.append(f" [Source: {hit.get('source', 'IDSA Guidelines')}]")

    elif agent_name == "vision_specialist":
        hits = search_mic_breakpoints(
            query,
            n_results=n_results,
            organism=info.get("organism"),
            antibiotic=info.get("antibiotic"),
        )
        if hits:
            sections.append("RELEVANT BREAKPOINT INFORMATION:")
            for hit in hits:
                sections.append(f"- {hit['content'][:400]}...")

    elif agent_name == "trend_analyst":
        bp_hits = search_mic_breakpoints(
            f"breakpoint {info.get('organism', '')} {info.get('antibiotic', '')}",
            n_results=n_results,
        )
        res_hits = search_resistance_patterns(
            query,
            n_results=n_results,
            organism=info.get("organism"),
            region=info.get("region"),
        )
        if bp_hits:
            sections.append("EUCAST BREAKPOINT DATA:")
            for hit in bp_hits:
                sections.append(f"- {hit['content'][:400]}...")
        if res_hits:
            sections.append("\nRESISTANCE PATTERN DATA:")
            for hit in res_hits:
                sections.append(f"- {hit['content'][:400]}...")

    elif agent_name == "clinical_pharmacologist":
        guide_hits = search_antibiotic_guidelines(query, n_results=n_results)
        safety_hits = search_drug_safety(
            query, n_results=n_results, drug_name=info.get("proposed_antibiotic")
        )
        if guide_hits:
            sections.append("TREATMENT GUIDELINES:")
            for hit in guide_hits:
                sections.append(f"- {hit['content'][:400]}...")
        if safety_hits:
            sections.append("\nDRUG SAFETY INFORMATION:")
            for hit in safety_hits:
                sections.append(f"- {hit['content'][:400]}...")

    else:
        # Unknown agent: fall back to the general guideline collection
        for hit in search_antibiotic_guidelines(query, n_results=n_results):
            sections.append(f"- {hit['content'][:500]}...")

    return "\n".join(sections) if sections else "No relevant context found in knowledge base."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
|
| 216 |
def _format_results(results: Dict[str, Any]) -> List[Dict[str, Any]]:
|
| 217 |
+
"""Flatten ChromaDB query results into a list of dicts."""
|
| 218 |
if not results or not results.get("documents"):
|
| 219 |
return []
|
| 220 |
|
|
|
|
| 221 |
documents = results["documents"][0] if results["documents"] else []
|
| 222 |
metadatas = results.get("metadatas", [[]])[0]
|
| 223 |
distances = results.get("distances", [[]])[0]
|
| 224 |
|
| 225 |
+
return [
|
| 226 |
+
{
|
| 227 |
"content": doc,
|
| 228 |
"metadata": metadatas[i] if i < len(metadatas) else {},
|
| 229 |
"distance": distances[i] if i < len(distances) else None,
|
| 230 |
"source": metadatas[i].get("source", "Unknown") if i < len(metadatas) else "Unknown",
|
| 231 |
"relevance_score": 1 - (distances[i] if i < len(distances) else 0),
|
| 232 |
+
}
|
| 233 |
+
for i, doc in enumerate(documents)
|
| 234 |
+
]
|
| 235 |
|
| 236 |
|
| 237 |
def list_available_collections() -> List[str]:
    """Return names of all ChromaDB collections that exist."""
    try:
        client = get_chroma_client()
        return [col.name for col in client.list_collections()]
    except Exception as e:
        logger.error(f"Error listing collections: {e}")
        return []
|
| 244 |
|
| 245 |
|
| 246 |
def get_collection_info(name: str) -> Optional[Dict[str, Any]]:
    """Return count and metadata for a collection, or None if it does not exist."""
    collection = get_collection(name)
    if collection is None:
        return None
    try:
        return {
            "name": collection.name,
            "count": collection.count(),
            "metadata": collection.metadata,
        }
    except Exception as e:
        logger.error(f"Error getting collection info: {e}")
        return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/state.py
CHANGED
|
@@ -1,12 +1,9 @@
|
|
| 1 |
|
| 2 |
-
from __future__ import annotations
|
| 3 |
-
|
| 4 |
from typing import Dict, List, Literal, NotRequired, Optional, TypedDict
|
| 5 |
|
| 6 |
|
| 7 |
class LabResult(TypedDict, total=False):
|
| 8 |
-
"""
|
| 9 |
-
|
| 10 |
name: str
|
| 11 |
value: str
|
| 12 |
unit: NotRequired[Optional[str]]
|
|
@@ -15,21 +12,19 @@ class LabResult(TypedDict, total=False):
|
|
| 15 |
|
| 16 |
|
| 17 |
class MICDatum(TypedDict, total=False):
|
| 18 |
-
"""
|
| 19 |
-
|
| 20 |
organism: str
|
| 21 |
antibiotic: str
|
| 22 |
mic_value: str
|
| 23 |
mic_unit: NotRequired[Optional[str]]
|
| 24 |
interpretation: NotRequired[Optional[Literal["S", "I", "R"]]]
|
| 25 |
-
breakpoint_source: NotRequired[Optional[str]] # e.g. EUCAST v16.0
|
| 26 |
year: NotRequired[Optional[int]]
|
| 27 |
-
site: NotRequired[Optional[str]] # e.g. blood, urine
|
| 28 |
|
| 29 |
|
| 30 |
class Recommendation(TypedDict, total=False):
|
| 31 |
"""Final clinical recommendation assembled by Agent 4."""
|
| 32 |
-
|
| 33 |
primary_antibiotic: Optional[str]
|
| 34 |
backup_antibiotic: NotRequired[Optional[str]]
|
| 35 |
dose: Optional[str]
|
|
@@ -43,24 +38,19 @@ class Recommendation(TypedDict, total=False):
|
|
| 43 |
|
| 44 |
class InfectionState(TypedDict, total=False):
|
| 45 |
"""
|
| 46 |
-
|
| 47 |
|
| 48 |
-
All
|
| 49 |
-
Most keys are optional to keep the schema flexible across stages.
|
| 50 |
"""
|
| 51 |
|
| 52 |
-
# ------------------------------------------------------------------
|
| 53 |
# Patient identity & demographics
|
| 54 |
-
# ------------------------------------------------------------------
|
| 55 |
patient_id: NotRequired[Optional[str]]
|
| 56 |
age_years: NotRequired[Optional[float]]
|
| 57 |
sex: NotRequired[Optional[Literal["male", "female", "other", "unknown"]]]
|
| 58 |
weight_kg: NotRequired[Optional[float]]
|
| 59 |
height_cm: NotRequired[Optional[float]]
|
| 60 |
|
| 61 |
-
# ------------------------------------------------------------------
|
| 62 |
# Clinical context
|
| 63 |
-
# ------------------------------------------------------------------
|
| 64 |
suspected_source: NotRequired[Optional[str]] # e.g. "community UTI"
|
| 65 |
comorbidities: NotRequired[List[str]]
|
| 66 |
medications: NotRequired[List[str]]
|
|
@@ -68,58 +58,36 @@ class InfectionState(TypedDict, total=False):
|
|
| 68 |
infection_site: NotRequired[Optional[str]]
|
| 69 |
country_or_region: NotRequired[Optional[str]]
|
| 70 |
|
| 71 |
-
#
|
| 72 |
-
# Renal function / vitals
|
| 73 |
-
# ------------------------------------------------------------------
|
| 74 |
serum_creatinine_mg_dl: NotRequired[Optional[float]]
|
| 75 |
creatinine_clearance_ml_min: NotRequired[Optional[float]]
|
| 76 |
vitals: NotRequired[Dict[str, str]] # flexible key/value, e.g. {"BP": "120/80"}
|
| 77 |
|
| 78 |
-
# ------------------------------------------------------------------
|
| 79 |
# Lab data & MICs
|
| 80 |
-
#
|
| 81 |
-
labs_raw_text: NotRequired[Optional[str]] # raw OCR / PDF text
|
| 82 |
labs_parsed: NotRequired[List[LabResult]]
|
| 83 |
-
|
| 84 |
mic_data: NotRequired[List[MICDatum]]
|
| 85 |
mic_trend_summary: NotRequired[Optional[str]]
|
| 86 |
|
| 87 |
-
#
|
| 88 |
-
# Stage / routing metadata
|
| 89 |
-
# ------------------------------------------------------------------
|
| 90 |
stage: NotRequired[Literal["empirical", "targeted"]]
|
| 91 |
route_to_vision: NotRequired[bool]
|
| 92 |
route_to_trend_analyst: NotRequired[bool]
|
| 93 |
|
| 94 |
-
# ------------------------------------------------------------------
|
| 95 |
# Agent outputs
|
| 96 |
-
#
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
pharmacology_notes: NotRequired[Optional[str]] # Agent 4
|
| 101 |
-
|
| 102 |
recommendation: NotRequired[Optional[Recommendation]]
|
| 103 |
|
| 104 |
-
#
|
| 105 |
-
# RAG / context + safety
|
| 106 |
-
# ------------------------------------------------------------------
|
| 107 |
rag_context: NotRequired[Optional[str]]
|
| 108 |
guideline_sources: NotRequired[List[str]]
|
| 109 |
breakpoint_sources: NotRequired[List[str]]
|
| 110 |
safety_warnings: NotRequired[List[str]]
|
| 111 |
|
| 112 |
-
#
|
| 113 |
-
# Diagnostics / debugging
|
| 114 |
-
# ------------------------------------------------------------------
|
| 115 |
errors: NotRequired[List[str]]
|
| 116 |
debug_log: NotRequired[List[str]]
|
| 117 |
|
| 118 |
-
|
| 119 |
-
__all__ = [
|
| 120 |
-
"LabResult",
|
| 121 |
-
"MICDatum",
|
| 122 |
-
"Recommendation",
|
| 123 |
-
"InfectionState",
|
| 124 |
-
]
|
| 125 |
-
|
|
|
|
| 1 |
|
|
|
|
|
|
|
| 2 |
from typing import Dict, List, Literal, NotRequired, Optional, TypedDict
|
| 3 |
|
| 4 |
|
| 5 |
class LabResult(TypedDict, total=False):
|
| 6 |
+
"""A single lab value with optional reference range and flag."""
|
|
|
|
| 7 |
name: str
|
| 8 |
value: str
|
| 9 |
unit: NotRequired[Optional[str]]
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
class MICDatum(TypedDict, total=False):
    """A single MIC measurement for one organism–antibiotic pair."""

    organism: str
    antibiotic: str
    # MIC is kept as a string, matching how it arrives from lab reports
    mic_value: str
    mic_unit: NotRequired[Optional[str]]
    # NOTE(review): S/I/R presumably follow EUCAST/CLSI susceptible /
    # intermediate / resistant categories — confirm with the producer
    interpretation: NotRequired[Optional[Literal["S", "I", "R"]]]
    breakpoint_source: NotRequired[Optional[str]]  # e.g. "EUCAST v16.0"
    year: NotRequired[Optional[int]]
    site: NotRequired[Optional[str]]  # e.g. "blood", "urine"
|
| 24 |
|
| 25 |
|
| 26 |
class Recommendation(TypedDict, total=False):
|
| 27 |
"""Final clinical recommendation assembled by Agent 4."""
|
|
|
|
| 28 |
primary_antibiotic: Optional[str]
|
| 29 |
backup_antibiotic: NotRequired[Optional[str]]
|
| 30 |
dose: Optional[str]
|
|
|
|
| 38 |
|
| 39 |
class InfectionState(TypedDict, total=False):
|
| 40 |
"""
|
| 41 |
+
Shared state object passed between all agents in the pipeline.
|
| 42 |
|
| 43 |
+
All keys are optional so each agent only needs to populate its own outputs.
|
|
|
|
| 44 |
"""
|
| 45 |
|
|
|
|
| 46 |
# Patient identity & demographics
|
|
|
|
| 47 |
patient_id: NotRequired[Optional[str]]
|
| 48 |
age_years: NotRequired[Optional[float]]
|
| 49 |
sex: NotRequired[Optional[Literal["male", "female", "other", "unknown"]]]
|
| 50 |
weight_kg: NotRequired[Optional[float]]
|
| 51 |
height_cm: NotRequired[Optional[float]]
|
| 52 |
|
|
|
|
| 53 |
# Clinical context
|
|
|
|
| 54 |
suspected_source: NotRequired[Optional[str]] # e.g. "community UTI"
|
| 55 |
comorbidities: NotRequired[List[str]]
|
| 56 |
medications: NotRequired[List[str]]
|
|
|
|
| 58 |
infection_site: NotRequired[Optional[str]]
|
| 59 |
country_or_region: NotRequired[Optional[str]]
|
| 60 |
|
| 61 |
+
# Renal function & vitals
|
|
|
|
|
|
|
| 62 |
serum_creatinine_mg_dl: NotRequired[Optional[float]]
|
| 63 |
creatinine_clearance_ml_min: NotRequired[Optional[float]]
|
| 64 |
vitals: NotRequired[Dict[str, str]] # flexible key/value, e.g. {"BP": "120/80"}
|
| 65 |
|
|
|
|
| 66 |
# Lab data & MICs
|
| 67 |
+
labs_raw_text: NotRequired[Optional[str]] # raw OCR or PDF text
|
|
|
|
| 68 |
labs_parsed: NotRequired[List[LabResult]]
|
|
|
|
| 69 |
mic_data: NotRequired[List[MICDatum]]
|
| 70 |
mic_trend_summary: NotRequired[Optional[str]]
|
| 71 |
|
| 72 |
+
# Routing flags set by agents
|
|
|
|
|
|
|
| 73 |
stage: NotRequired[Literal["empirical", "targeted"]]
|
| 74 |
route_to_vision: NotRequired[bool]
|
| 75 |
route_to_trend_analyst: NotRequired[bool]
|
| 76 |
|
|
|
|
| 77 |
# Agent outputs
|
| 78 |
+
intake_notes: NotRequired[Optional[str]] # Agent 1
|
| 79 |
+
vision_notes: NotRequired[Optional[str]] # Agent 2
|
| 80 |
+
trend_notes: NotRequired[Optional[str]] # Agent 3
|
| 81 |
+
pharmacology_notes: NotRequired[Optional[str]] # Agent 4
|
|
|
|
|
|
|
| 82 |
recommendation: NotRequired[Optional[Recommendation]]
|
| 83 |
|
| 84 |
+
# RAG context & safety
|
|
|
|
|
|
|
| 85 |
rag_context: NotRequired[Optional[str]]
|
| 86 |
guideline_sources: NotRequired[List[str]]
|
| 87 |
breakpoint_sources: NotRequired[List[str]]
|
| 88 |
safety_warnings: NotRequired[List[str]]
|
| 89 |
|
| 90 |
+
# Diagnostics
|
|
|
|
|
|
|
| 91 |
errors: NotRequired[List[str]]
|
| 92 |
debug_log: NotRequired[List[str]]
|
| 93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/tools/rag_tools.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
"""RAG tools for querying clinical guidelines via ChromaDB."""
|
| 2 |
|
| 3 |
-
from typing import Optional
|
| 4 |
from src.db.vector_store import search_guidelines, search_mic_reference
|
| 5 |
|
| 6 |
|
|
|
|
| 1 |
"""RAG tools for querying clinical guidelines via ChromaDB."""
|
| 2 |
|
|
|
|
| 3 |
from src.db.vector_store import search_guidelines, search_mic_reference
|
| 4 |
|
| 5 |
|
src/utils.py
CHANGED
|
@@ -1,23 +1,19 @@
|
|
| 1 |
"""
|
| 2 |
-
Utility functions for
|
| 3 |
|
| 4 |
-
|
| 5 |
-
- Creatinine Clearance (CrCl) calculator
|
| 6 |
- MIC trend analysis and creep detection
|
| 7 |
- Prescription card formatter
|
| 8 |
-
-
|
| 9 |
"""
|
| 10 |
|
| 11 |
-
from __future__ import annotations
|
| 12 |
-
|
| 13 |
import json
|
| 14 |
import math
|
|
|
|
| 15 |
from typing import Any, Dict, List, Literal, Optional, Tuple
|
| 16 |
|
| 17 |
|
| 18 |
-
#
|
| 19 |
-
# CREATININE CLEARANCE CALCULATOR
|
| 20 |
-
# =============================================================================
|
| 21 |
|
| 22 |
def calculate_crcl(
|
| 23 |
age_years: float,
|
|
@@ -28,40 +24,25 @@ def calculate_crcl(
|
|
| 28 |
height_cm: Optional[float] = None,
|
| 29 |
) -> float:
|
| 30 |
"""
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
Formula:
|
| 34 |
-
CrCl = [(140 - age) × weight × (0.85 if female)] / (72 × SCr)
|
| 35 |
|
| 36 |
-
|
| 37 |
-
age_years: Patient age in years
|
| 38 |
-
weight_kg: Actual body weight in kg
|
| 39 |
-
serum_creatinine_mg_dl: Serum creatinine in mg/dL
|
| 40 |
-
sex: Patient sex ("male" or "female")
|
| 41 |
-
use_ibw: If True, use Ideal Body Weight instead of actual weight
|
| 42 |
-
height_cm: Height in cm (required if use_ibw=True)
|
| 43 |
|
| 44 |
-
|
| 45 |
-
|
|
|
|
| 46 |
"""
|
| 47 |
if serum_creatinine_mg_dl <= 0:
|
| 48 |
raise ValueError("Serum creatinine must be positive")
|
| 49 |
-
|
| 50 |
if age_years <= 0 or weight_kg <= 0:
|
| 51 |
raise ValueError("Age and weight must be positive")
|
| 52 |
|
| 53 |
-
# Calculate weight to use
|
| 54 |
weight = weight_kg
|
| 55 |
if use_ibw and height_cm:
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
if weight_kg > weight * 1.3:
|
| 59 |
-
weight = calculate_adjusted_bw(weight, weight_kg)
|
| 60 |
|
| 61 |
-
# Cockcroft-Gault equation
|
| 62 |
crcl = ((140 - age_years) * weight) / (72 * serum_creatinine_mg_dl)
|
| 63 |
-
|
| 64 |
-
# Apply sex factor
|
| 65 |
if sex == "female":
|
| 66 |
crcl *= 0.85
|
| 67 |
|
|
@@ -70,42 +51,27 @@ def calculate_crcl(
|
|
| 70 |
|
| 71 |
def calculate_ibw(height_cm: float, sex: Literal["male", "female"]) -> float:
|
| 72 |
"""
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
Args:
|
| 76 |
-
height_cm: Height in centimeters
|
| 77 |
-
sex: Patient sex
|
| 78 |
|
| 79 |
-
|
| 80 |
-
|
| 81 |
"""
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
if sex == "male":
|
| 86 |
-
ibw = 50 + 2.3 * height_over_60
|
| 87 |
-
else:
|
| 88 |
-
ibw = 45.5 + 2.3 * height_over_60
|
| 89 |
-
|
| 90 |
-
return round(ibw, 1)
|
| 91 |
|
| 92 |
|
| 93 |
def calculate_adjusted_bw(ibw: float, actual_weight: float) -> float:
    """Return adjusted body weight: IBW plus 40% of the excess over IBW, rounded to 1 dp."""
    excess = actual_weight - ibw
    return round(ibw + 0.4 * excess, 1)
|
| 100 |
|
| 101 |
|
| 102 |
def get_renal_dose_category(crcl: float) -> str:
|
| 103 |
-
"""
|
| 104 |
-
Categorize renal function for dosing purposes.
|
| 105 |
-
|
| 106 |
-
Returns:
|
| 107 |
-
Renal function category
|
| 108 |
-
"""
|
| 109 |
if crcl >= 90:
|
| 110 |
return "normal"
|
| 111 |
elif crcl >= 60:
|
|
@@ -118,9 +84,7 @@ def get_renal_dose_category(crcl: float) -> str:
|
|
| 118 |
return "esrd"
|
| 119 |
|
| 120 |
|
| 121 |
-
#
|
| 122 |
-
# MIC TREND ANALYSIS
|
| 123 |
-
# =============================================================================
|
| 124 |
|
| 125 |
def calculate_mic_trend(
|
| 126 |
mic_values: List[Dict[str, Any]],
|
|
@@ -128,15 +92,11 @@ def calculate_mic_trend(
|
|
| 128 |
resistant_breakpoint: Optional[float] = None,
|
| 129 |
) -> Dict[str, Any]:
|
| 130 |
"""
|
| 131 |
-
Analyze
|
| 132 |
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
resistant_breakpoint: R breakpoint (optional)
|
| 137 |
-
|
| 138 |
-
Returns:
|
| 139 |
-
Dict with trend analysis results
|
| 140 |
"""
|
| 141 |
if len(mic_values) < 2:
|
| 142 |
return {
|
|
@@ -145,49 +105,28 @@ def calculate_mic_trend(
|
|
| 145 |
"alert": "Need at least 2 MIC values for trend analysis",
|
| 146 |
}
|
| 147 |
|
| 148 |
-
# Extract MIC values
|
| 149 |
mics = [float(v["mic_value"]) for v in mic_values]
|
| 150 |
-
|
| 151 |
baseline_mic = mics[0]
|
| 152 |
current_mic = mics[-1]
|
|
|
|
| 153 |
|
| 154 |
-
# Calculate fold change
|
| 155 |
-
if baseline_mic > 0:
|
| 156 |
-
fold_change = current_mic / baseline_mic
|
| 157 |
-
else:
|
| 158 |
-
fold_change = float("inf")
|
| 159 |
-
|
| 160 |
-
# Calculate trend
|
| 161 |
if len(mics) >= 3:
|
| 162 |
-
# Linear regression slope
|
| 163 |
n = len(mics)
|
| 164 |
x_mean = (n - 1) / 2
|
| 165 |
y_mean = sum(mics) / n
|
| 166 |
numerator = sum((i - x_mean) * (mics[i] - y_mean) for i in range(n))
|
| 167 |
denominator = sum((i - x_mean) ** 2 for i in range(n))
|
| 168 |
slope = numerator / denominator if denominator != 0 else 0
|
| 169 |
-
|
| 170 |
-
if slope > 0.5:
|
| 171 |
-
trend = "increasing"
|
| 172 |
-
elif slope < -0.5:
|
| 173 |
-
trend = "decreasing"
|
| 174 |
-
else:
|
| 175 |
-
trend = "stable"
|
| 176 |
else:
|
| 177 |
-
if current_mic > baseline_mic * 1.5
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
trend = "decreasing"
|
| 181 |
-
else:
|
| 182 |
-
trend = "stable"
|
| 183 |
-
|
| 184 |
-
# Calculate resistance velocity (fold change per time point)
|
| 185 |
velocity = fold_change ** (1 / (len(mics) - 1)) if len(mics) > 1 else 1.0
|
| 186 |
|
| 187 |
-
# Determine risk level
|
| 188 |
risk_level, alert = _assess_mic_risk(
|
| 189 |
current_mic, baseline_mic, fold_change, trend,
|
| 190 |
-
susceptible_breakpoint, resistant_breakpoint
|
| 191 |
)
|
| 192 |
|
| 193 |
return {
|
|
@@ -211,51 +150,39 @@ def _assess_mic_risk(
|
|
| 211 |
r_breakpoint: Optional[float],
|
| 212 |
) -> Tuple[str, str]:
|
| 213 |
"""
|
| 214 |
-
|
| 215 |
|
| 216 |
-
|
| 217 |
-
|
| 218 |
"""
|
| 219 |
-
# If we have breakpoints, use them for risk assessment
|
| 220 |
if s_breakpoint is not None and r_breakpoint is not None:
|
| 221 |
margin = s_breakpoint / current_mic if current_mic > 0 else float("inf")
|
| 222 |
|
| 223 |
if current_mic > r_breakpoint:
|
| 224 |
return "CRITICAL", f"MIC ({current_mic}) exceeds resistant breakpoint ({r_breakpoint}). Organism is RESISTANT."
|
| 225 |
-
|
| 226 |
if current_mic > s_breakpoint:
|
| 227 |
return "HIGH", f"MIC ({current_mic}) exceeds susceptible breakpoint ({s_breakpoint}). Consider alternative therapy."
|
| 228 |
-
|
| 229 |
if margin < 2:
|
| 230 |
if trend == "increasing":
|
| 231 |
return "HIGH", f"MIC approaching breakpoint (margin: {margin:.1f}x) with increasing trend. High risk of resistance emergence."
|
| 232 |
-
|
| 233 |
-
return "MODERATE", f"MIC close to breakpoint (margin: {margin:.1f}x). Monitor closely."
|
| 234 |
-
|
| 235 |
if margin < 4:
|
| 236 |
if trend == "increasing":
|
| 237 |
return "MODERATE", f"MIC rising with {margin:.1f}x margin to breakpoint. Consider enhanced monitoring."
|
| 238 |
-
|
| 239 |
-
return "LOW", "MIC stable with adequate margin to breakpoint."
|
| 240 |
-
|
| 241 |
return "LOW", "MIC well below breakpoint with good safety margin."
|
| 242 |
|
| 243 |
-
#
|
| 244 |
if fold_change >= 8:
|
| 245 |
return "CRITICAL", f"MIC increased {fold_change:.1f}-fold from baseline. Urgent review needed."
|
| 246 |
-
|
| 247 |
if fold_change >= 4:
|
| 248 |
return "HIGH", f"MIC increased {fold_change:.1f}-fold from baseline. High risk of treatment failure."
|
| 249 |
-
|
| 250 |
if fold_change >= 2:
|
| 251 |
if trend == "increasing":
|
| 252 |
return "MODERATE", f"MIC increased {fold_change:.1f}-fold with rising trend. Enhanced monitoring recommended."
|
| 253 |
-
|
| 254 |
-
return "LOW", f"MIC increased {fold_change:.1f}-fold but trend is {trend}."
|
| 255 |
-
|
| 256 |
if trend == "increasing":
|
| 257 |
return "MODERATE", "MIC showing upward trend. Continue monitoring."
|
| 258 |
-
|
| 259 |
return "LOW", "MIC stable or decreasing. Current therapy appropriate."
|
| 260 |
|
| 261 |
|
|
@@ -268,58 +195,37 @@ def detect_mic_creep(
|
|
| 268 |
"""
|
| 269 |
Detect MIC creep for a specific organism-antibiotic pair.
|
| 270 |
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
antibiotic: Antibiotic name
|
| 274 |
-
mic_history: Historical MIC values with dates
|
| 275 |
-
breakpoints: Dict with 'susceptible' and 'resistant' keys
|
| 276 |
-
|
| 277 |
-
Returns:
|
| 278 |
-
Comprehensive MIC creep analysis
|
| 279 |
"""
|
| 280 |
-
|
| 281 |
mic_history,
|
| 282 |
susceptible_breakpoint=breakpoints.get("susceptible"),
|
| 283 |
resistant_breakpoint=breakpoints.get("resistant"),
|
| 284 |
)
|
| 285 |
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
trend_analysis["breakpoint_resistant"] = breakpoints.get("resistant")
|
| 291 |
|
| 292 |
-
#
|
| 293 |
-
if
|
| 294 |
-
current =
|
| 295 |
s_bp = breakpoints.get("susceptible")
|
| 296 |
if s_bp and current < s_bp:
|
| 297 |
-
# Estimate doublings needed to reach breakpoint
|
| 298 |
doublings_needed = math.log2(s_bp / current) if current > 0 else 0
|
| 299 |
-
|
| 300 |
-
if
|
| 301 |
-
|
| 302 |
-
if log_velocity > 0:
|
| 303 |
-
time_estimate = doublings_needed / log_velocity
|
| 304 |
-
trend_analysis["estimated_readings_to_resistance"] = round(time_estimate, 1)
|
| 305 |
|
| 306 |
-
return
|
| 307 |
|
| 308 |
|
| 309 |
-
#
|
| 310 |
-
# PRESCRIPTION FORMATTER
|
| 311 |
-
# =============================================================================
|
| 312 |
|
| 313 |
def format_prescription_card(recommendation: Dict[str, Any]) -> str:
|
| 314 |
-
"""
|
| 315 |
-
Format a recommendation into a readable prescription card.
|
| 316 |
-
|
| 317 |
-
Args:
|
| 318 |
-
recommendation: Dict with recommendation details
|
| 319 |
-
|
| 320 |
-
Returns:
|
| 321 |
-
Formatted prescription card as string
|
| 322 |
-
"""
|
| 323 |
lines = []
|
| 324 |
lines.append("=" * 50)
|
| 325 |
lines.append("ANTIBIOTIC PRESCRIPTION")
|
|
@@ -336,14 +242,12 @@ def format_prescription_card(recommendation: Dict[str, Any]) -> str:
|
|
| 336 |
if primary.get("aware_category"):
|
| 337 |
lines.append(f"WHO AWaRe: {primary.get('aware_category')}")
|
| 338 |
|
| 339 |
-
# Dose adjustments
|
| 340 |
adjustments = recommendation.get("dose_adjustments", {})
|
| 341 |
if adjustments.get("renal") and adjustments["renal"] != "None needed":
|
| 342 |
lines.append(f"\nRENAL ADJUSTMENT: {adjustments['renal']}")
|
| 343 |
if adjustments.get("hepatic") and adjustments["hepatic"] != "None needed":
|
| 344 |
lines.append(f"HEPATIC ADJUSTMENT: {adjustments['hepatic']}")
|
| 345 |
|
| 346 |
-
# Safety alerts
|
| 347 |
alerts = recommendation.get("safety_alerts", [])
|
| 348 |
if alerts:
|
| 349 |
lines.append("\n" + "-" * 50)
|
|
@@ -353,7 +257,6 @@ def format_prescription_card(recommendation: Dict[str, Any]) -> str:
|
|
| 353 |
marker = {"CRITICAL": "[!!!]", "WARNING": "[!!]", "INFO": "[i]"}.get(level, "[?]")
|
| 354 |
lines.append(f" {marker} {alert.get('message', '')}")
|
| 355 |
|
| 356 |
-
# Monitoring
|
| 357 |
monitoring = recommendation.get("monitoring_parameters", [])
|
| 358 |
if monitoring:
|
| 359 |
lines.append("\n" + "-" * 50)
|
|
@@ -361,47 +264,33 @@ def format_prescription_card(recommendation: Dict[str, Any]) -> str:
|
|
| 361 |
for param in monitoring:
|
| 362 |
lines.append(f" - {param}")
|
| 363 |
|
| 364 |
-
# Rationale
|
| 365 |
if recommendation.get("rationale"):
|
| 366 |
lines.append("\n" + "-" * 50)
|
| 367 |
lines.append("RATIONALE:")
|
| 368 |
lines.append(f" {recommendation['rationale']}")
|
| 369 |
|
| 370 |
lines.append("\n" + "=" * 50)
|
| 371 |
-
|
| 372 |
return "\n".join(lines)
|
| 373 |
|
| 374 |
|
| 375 |
-
#
|
| 376 |
-
# JSON PARSING HELPERS
|
| 377 |
-
# =============================================================================
|
| 378 |
|
| 379 |
def safe_json_parse(text: str) -> Optional[Dict[str, Any]]:
|
| 380 |
"""
|
| 381 |
-
|
| 382 |
|
| 383 |
-
|
| 384 |
-
|
| 385 |
"""
|
| 386 |
if not text:
|
| 387 |
return None
|
| 388 |
|
| 389 |
-
# Try direct parse first
|
| 390 |
try:
|
| 391 |
return json.loads(text)
|
| 392 |
except json.JSONDecodeError:
|
| 393 |
pass
|
| 394 |
|
| 395 |
-
|
| 396 |
-
import re
|
| 397 |
-
|
| 398 |
-
json_patterns = [
|
| 399 |
-
r"```json\s*\n?(.*?)\n?```", # ```json ... ```
|
| 400 |
-
r"```\s*\n?(.*?)\n?```", # ``` ... ```
|
| 401 |
-
r"\{[\s\S]*\}", # Raw JSON object
|
| 402 |
-
]
|
| 403 |
-
|
| 404 |
-
for pattern in json_patterns:
|
| 405 |
match = re.search(pattern, text, re.DOTALL)
|
| 406 |
if match:
|
| 407 |
try:
|
|
@@ -414,29 +303,15 @@ def safe_json_parse(text: str) -> Optional[Dict[str, Any]]:
|
|
| 414 |
|
| 415 |
|
| 416 |
def validate_agent_output(output: Dict[str, Any], required_fields: List[str]) -> Tuple[bool, List[str]]:
|
| 417 |
-
"""
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
Args:
|
| 421 |
-
output: Agent output dict
|
| 422 |
-
required_fields: List of required field names
|
| 423 |
-
|
| 424 |
-
Returns:
|
| 425 |
-
Tuple of (is_valid, list_of_missing_fields)
|
| 426 |
-
"""
|
| 427 |
-
missing = [field for field in required_fields if field not in output]
|
| 428 |
return len(missing) == 0, missing
|
| 429 |
|
| 430 |
|
| 431 |
-
#
|
| 432 |
-
# DATA NORMALIZATION
|
| 433 |
-
# =============================================================================
|
| 434 |
|
| 435 |
def normalize_antibiotic_name(name: str) -> str:
|
| 436 |
-
"""
|
| 437 |
-
Normalize antibiotic name to standard format.
|
| 438 |
-
"""
|
| 439 |
-
# Common name mappings
|
| 440 |
mappings = {
|
| 441 |
"amox": "amoxicillin",
|
| 442 |
"amox/clav": "amoxicillin-clavulanate",
|
|
@@ -459,18 +334,11 @@ def normalize_antibiotic_name(name: str) -> str:
|
|
| 459 |
"cefepime": "cefepime",
|
| 460 |
"maxipime": "cefepime",
|
| 461 |
}
|
| 462 |
-
|
| 463 |
-
normalized = name.lower().strip()
|
| 464 |
-
return mappings.get(normalized, normalized)
|
| 465 |
|
| 466 |
|
| 467 |
def normalize_organism_name(name: str) -> str:
|
| 468 |
-
"""
|
| 469 |
-
Normalize organism name to standard format.
|
| 470 |
-
"""
|
| 471 |
-
name = name.strip()
|
| 472 |
-
|
| 473 |
-
# Common abbreviations
|
| 474 |
abbreviations = {
|
| 475 |
"e. coli": "Escherichia coli",
|
| 476 |
"e.coli": "Escherichia coli",
|
|
@@ -485,21 +353,4 @@ def normalize_organism_name(name: str) -> str:
|
|
| 485 |
"enterococcus": "Enterococcus species",
|
| 486 |
"vre": "Enterococcus (VRE)",
|
| 487 |
}
|
| 488 |
-
|
| 489 |
-
lower_name = name.lower()
|
| 490 |
-
return abbreviations.get(lower_name, name)
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
__all__ = [
|
| 494 |
-
"calculate_crcl",
|
| 495 |
-
"calculate_ibw",
|
| 496 |
-
"calculate_adjusted_bw",
|
| 497 |
-
"get_renal_dose_category",
|
| 498 |
-
"calculate_mic_trend",
|
| 499 |
-
"detect_mic_creep",
|
| 500 |
-
"format_prescription_card",
|
| 501 |
-
"safe_json_parse",
|
| 502 |
-
"validate_agent_output",
|
| 503 |
-
"normalize_antibiotic_name",
|
| 504 |
-
"normalize_organism_name",
|
| 505 |
-
]
|
|
|
|
| 1 |
"""
|
| 2 |
+
Utility functions for clinical calculations and data parsing.
|
| 3 |
|
| 4 |
+
- Creatinine Clearance (CrCl) via Cockcroft-Gault
|
|
|
|
| 5 |
- MIC trend analysis and creep detection
|
| 6 |
- Prescription card formatter
|
| 7 |
+
- JSON parsing and data normalization helpers
|
| 8 |
"""
|
| 9 |
|
|
|
|
|
|
|
| 10 |
import json
|
| 11 |
import math
|
| 12 |
+
import re
|
| 13 |
from typing import Any, Dict, List, Literal, Optional, Tuple
|
| 14 |
|
| 15 |
|
| 16 |
+
# --- CrCl calculator ---
|
|
|
|
|
|
|
| 17 |
|
| 18 |
def calculate_crcl(
|
| 19 |
age_years: float,
|
|
|
|
| 24 |
height_cm: Optional[float] = None,
|
| 25 |
) -> float:
|
| 26 |
"""
|
| 27 |
+
Cockcroft-Gault equation.
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
+
CrCl = [(140 - age) × weight × (0.85 if female)] / (72 × SCr)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
+
When use_ibw=True and height is given, uses Ideal Body Weight.
|
| 32 |
+
For obese patients (actual > 1.3 × IBW), switches to Adjusted Body Weight.
|
| 33 |
+
Returns CrCl in mL/min.
|
| 34 |
"""
|
| 35 |
if serum_creatinine_mg_dl <= 0:
|
| 36 |
raise ValueError("Serum creatinine must be positive")
|
|
|
|
| 37 |
if age_years <= 0 or weight_kg <= 0:
|
| 38 |
raise ValueError("Age and weight must be positive")
|
| 39 |
|
|
|
|
| 40 |
weight = weight_kg
|
| 41 |
if use_ibw and height_cm:
|
| 42 |
+
ibw = calculate_ibw(height_cm, sex)
|
| 43 |
+
weight = calculate_adjusted_bw(ibw, weight_kg) if weight_kg > ibw * 1.3 else ibw
|
|
|
|
|
|
|
| 44 |
|
|
|
|
| 45 |
crcl = ((140 - age_years) * weight) / (72 * serum_creatinine_mg_dl)
|
|
|
|
|
|
|
| 46 |
if sex == "female":
|
| 47 |
crcl *= 0.85
|
| 48 |
|
|
|
|
| 51 |
|
| 52 |
def calculate_ibw(height_cm: float, sex: Literal["male", "female"]) -> float:
    """
    Devine formula for Ideal Body Weight.

    Male:   50 kg + 2.3 kg per inch over 5 feet
    Female: 45.5 kg + 2.3 kg per inch over 5 feet

    Args:
        height_cm: Patient height in centimeters; must be positive.
        sex: "male" or "female".

    Returns:
        Ideal body weight in kg, rounded to 1 decimal place.

    Raises:
        ValueError: If height_cm is not positive (mirrors the input
            validation style used by calculate_crcl).
    """
    if height_cm <= 0:
        raise ValueError("Height must be positive")
    # Inches above 5 feet; clamped at 0 so very short patients
    # receive the base weight rather than a negative adjustment.
    height_over_60_inches = max(0, height_cm / 2.54 - 60)
    base = 50 if sex == "male" else 45.5
    return round(base + 2.3 * height_over_60_inches, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
|
| 64 |
def calculate_adjusted_bw(ibw: float, actual_weight: float) -> float:
    """
    Adjusted Body Weight for obese patients.

    AdjBW = IBW + 0.4 × (Actual - IBW)

    Args:
        ibw: Ideal body weight in kg.
        actual_weight: Actual (measured) body weight in kg.

    Returns:
        Adjusted body weight in kg, rounded to 1 decimal place.
    """
    # Only 40% of the weight above ideal contributes to the dosing weight.
    excess_weight = actual_weight - ibw
    adjusted = ibw + 0.4 * excess_weight
    return round(adjusted, 1)
|
| 71 |
|
| 72 |
|
| 73 |
def get_renal_dose_category(crcl: float) -> str:
|
| 74 |
+
"""Map CrCl value to a dosing category string."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
if crcl >= 90:
|
| 76 |
return "normal"
|
| 77 |
elif crcl >= 60:
|
|
|
|
| 84 |
return "esrd"
|
| 85 |
|
| 86 |
|
| 87 |
+
# --- MIC trend analysis ---
|
|
|
|
|
|
|
| 88 |
|
| 89 |
def calculate_mic_trend(
|
| 90 |
mic_values: List[Dict[str, Any]],
|
|
|
|
| 92 |
resistant_breakpoint: Optional[float] = None,
|
| 93 |
) -> Dict[str, Any]:
|
| 94 |
"""
|
| 95 |
+
Analyze a list of MIC readings over time.
|
| 96 |
|
| 97 |
+
Requires at least 2 readings. Uses linear regression slope for trend
|
| 98 |
+
direction when >= 3 points are available; falls back to ratio comparison
|
| 99 |
+
for exactly 2 points.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
"""
|
| 101 |
if len(mic_values) < 2:
|
| 102 |
return {
|
|
|
|
| 105 |
"alert": "Need at least 2 MIC values for trend analysis",
|
| 106 |
}
|
| 107 |
|
|
|
|
| 108 |
mics = [float(v["mic_value"]) for v in mic_values]
|
|
|
|
| 109 |
baseline_mic = mics[0]
|
| 110 |
current_mic = mics[-1]
|
| 111 |
+
fold_change = (current_mic / baseline_mic) if baseline_mic > 0 else float("inf")
|
| 112 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
if len(mics) >= 3:
|
|
|
|
| 114 |
n = len(mics)
|
| 115 |
x_mean = (n - 1) / 2
|
| 116 |
y_mean = sum(mics) / n
|
| 117 |
numerator = sum((i - x_mean) * (mics[i] - y_mean) for i in range(n))
|
| 118 |
denominator = sum((i - x_mean) ** 2 for i in range(n))
|
| 119 |
slope = numerator / denominator if denominator != 0 else 0
|
| 120 |
+
trend = "increasing" if slope > 0.5 else "decreasing" if slope < -0.5 else "stable"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
else:
|
| 122 |
+
trend = "increasing" if current_mic > baseline_mic * 1.5 else "decreasing" if current_mic < baseline_mic * 0.67 else "stable"
|
| 123 |
+
|
| 124 |
+
# Fold change per time step (geometric rate of change)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
velocity = fold_change ** (1 / (len(mics) - 1)) if len(mics) > 1 else 1.0
|
| 126 |
|
|
|
|
| 127 |
risk_level, alert = _assess_mic_risk(
|
| 128 |
current_mic, baseline_mic, fold_change, trend,
|
| 129 |
+
susceptible_breakpoint, resistant_breakpoint,
|
| 130 |
)
|
| 131 |
|
| 132 |
return {
|
|
|
|
| 150 |
r_breakpoint: Optional[float],
|
| 151 |
) -> Tuple[str, str]:
|
| 152 |
"""
|
| 153 |
+
Assign a risk level (LOW/MODERATE/HIGH/CRITICAL) based on breakpoints and fold change.
|
| 154 |
|
| 155 |
+
Prefers breakpoint-based assessment when breakpoints are available.
|
| 156 |
+
Falls back to fold-change thresholds otherwise.
|
| 157 |
"""
|
|
|
|
| 158 |
if s_breakpoint is not None and r_breakpoint is not None:
|
| 159 |
margin = s_breakpoint / current_mic if current_mic > 0 else float("inf")
|
| 160 |
|
| 161 |
if current_mic > r_breakpoint:
|
| 162 |
return "CRITICAL", f"MIC ({current_mic}) exceeds resistant breakpoint ({r_breakpoint}). Organism is RESISTANT."
|
|
|
|
| 163 |
if current_mic > s_breakpoint:
|
| 164 |
return "HIGH", f"MIC ({current_mic}) exceeds susceptible breakpoint ({s_breakpoint}). Consider alternative therapy."
|
|
|
|
| 165 |
if margin < 2:
|
| 166 |
if trend == "increasing":
|
| 167 |
return "HIGH", f"MIC approaching breakpoint (margin: {margin:.1f}x) with increasing trend. High risk of resistance emergence."
|
| 168 |
+
return "MODERATE", f"MIC close to breakpoint (margin: {margin:.1f}x). Monitor closely."
|
|
|
|
|
|
|
| 169 |
if margin < 4:
|
| 170 |
if trend == "increasing":
|
| 171 |
return "MODERATE", f"MIC rising with {margin:.1f}x margin to breakpoint. Consider enhanced monitoring."
|
| 172 |
+
return "LOW", "MIC stable with adequate margin to breakpoint."
|
|
|
|
|
|
|
| 173 |
return "LOW", "MIC well below breakpoint with good safety margin."
|
| 174 |
|
| 175 |
+
# No breakpoints — use fold change thresholds from EUCAST MIC creep criteria
|
| 176 |
if fold_change >= 8:
|
| 177 |
return "CRITICAL", f"MIC increased {fold_change:.1f}-fold from baseline. Urgent review needed."
|
|
|
|
| 178 |
if fold_change >= 4:
|
| 179 |
return "HIGH", f"MIC increased {fold_change:.1f}-fold from baseline. High risk of treatment failure."
|
|
|
|
| 180 |
if fold_change >= 2:
|
| 181 |
if trend == "increasing":
|
| 182 |
return "MODERATE", f"MIC increased {fold_change:.1f}-fold with rising trend. Enhanced monitoring recommended."
|
| 183 |
+
return "LOW", f"MIC increased {fold_change:.1f}-fold but trend is {trend}."
|
|
|
|
|
|
|
| 184 |
if trend == "increasing":
|
| 185 |
return "MODERATE", "MIC showing upward trend. Continue monitoring."
|
|
|
|
| 186 |
return "LOW", "MIC stable or decreasing. Current therapy appropriate."
|
| 187 |
|
| 188 |
|
|
|
|
| 195 |
"""
|
| 196 |
Detect MIC creep for a specific organism-antibiotic pair.
|
| 197 |
|
| 198 |
+
Augments calculate_mic_trend with a time-to-resistance estimate
|
| 199 |
+
when the MIC is rising and a susceptible breakpoint is available.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
"""
|
| 201 |
+
result = calculate_mic_trend(
|
| 202 |
mic_history,
|
| 203 |
susceptible_breakpoint=breakpoints.get("susceptible"),
|
| 204 |
resistant_breakpoint=breakpoints.get("resistant"),
|
| 205 |
)
|
| 206 |
|
| 207 |
+
result["organism"] = organism
|
| 208 |
+
result["antibiotic"] = antibiotic
|
| 209 |
+
result["breakpoint_susceptible"] = breakpoints.get("susceptible")
|
| 210 |
+
result["breakpoint_resistant"] = breakpoints.get("resistant")
|
|
|
|
| 211 |
|
| 212 |
+
# Estimate how many more time-points until MIC reaches the susceptible breakpoint
|
| 213 |
+
if result["trend"] == "increasing" and result["velocity"] > 1.0:
|
| 214 |
+
current = result["current_mic"]
|
| 215 |
s_bp = breakpoints.get("susceptible")
|
| 216 |
if s_bp and current < s_bp:
|
|
|
|
| 217 |
doublings_needed = math.log2(s_bp / current) if current > 0 else 0
|
| 218 |
+
log_velocity = math.log(result["velocity"]) / math.log(2)
|
| 219 |
+
if log_velocity > 0:
|
| 220 |
+
result["estimated_readings_to_resistance"] = round(doublings_needed / log_velocity, 1)
|
|
|
|
|
|
|
|
|
|
| 221 |
|
| 222 |
+
return result
|
| 223 |
|
| 224 |
|
| 225 |
+
# --- Prescription formatter ---
|
|
|
|
|
|
|
| 226 |
|
| 227 |
def format_prescription_card(recommendation: Dict[str, Any]) -> str:
|
| 228 |
+
"""Format a recommendation dict as a plain-text prescription card."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
lines = []
|
| 230 |
lines.append("=" * 50)
|
| 231 |
lines.append("ANTIBIOTIC PRESCRIPTION")
|
|
|
|
| 242 |
if primary.get("aware_category"):
|
| 243 |
lines.append(f"WHO AWaRe: {primary.get('aware_category')}")
|
| 244 |
|
|
|
|
| 245 |
adjustments = recommendation.get("dose_adjustments", {})
|
| 246 |
if adjustments.get("renal") and adjustments["renal"] != "None needed":
|
| 247 |
lines.append(f"\nRENAL ADJUSTMENT: {adjustments['renal']}")
|
| 248 |
if adjustments.get("hepatic") and adjustments["hepatic"] != "None needed":
|
| 249 |
lines.append(f"HEPATIC ADJUSTMENT: {adjustments['hepatic']}")
|
| 250 |
|
|
|
|
| 251 |
alerts = recommendation.get("safety_alerts", [])
|
| 252 |
if alerts:
|
| 253 |
lines.append("\n" + "-" * 50)
|
|
|
|
| 257 |
marker = {"CRITICAL": "[!!!]", "WARNING": "[!!]", "INFO": "[i]"}.get(level, "[?]")
|
| 258 |
lines.append(f" {marker} {alert.get('message', '')}")
|
| 259 |
|
|
|
|
| 260 |
monitoring = recommendation.get("monitoring_parameters", [])
|
| 261 |
if monitoring:
|
| 262 |
lines.append("\n" + "-" * 50)
|
|
|
|
| 264 |
for param in monitoring:
|
| 265 |
lines.append(f" - {param}")
|
| 266 |
|
|
|
|
| 267 |
if recommendation.get("rationale"):
|
| 268 |
lines.append("\n" + "-" * 50)
|
| 269 |
lines.append("RATIONALE:")
|
| 270 |
lines.append(f" {recommendation['rationale']}")
|
| 271 |
|
| 272 |
lines.append("\n" + "=" * 50)
|
|
|
|
| 273 |
return "\n".join(lines)
|
| 274 |
|
| 275 |
|
| 276 |
+
# --- JSON parsing ---
|
|
|
|
|
|
|
| 277 |
|
| 278 |
def safe_json_parse(text: str) -> Optional[Dict[str, Any]]:
|
| 279 |
"""
|
| 280 |
+
Extract and parse the first JSON object from a string.
|
| 281 |
|
| 282 |
+
Handles model output that may wrap JSON in markdown code fences.
|
| 283 |
+
Returns None if no valid JSON is found.
|
| 284 |
"""
|
| 285 |
if not text:
|
| 286 |
return None
|
| 287 |
|
|
|
|
| 288 |
try:
|
| 289 |
return json.loads(text)
|
| 290 |
except json.JSONDecodeError:
|
| 291 |
pass
|
| 292 |
|
| 293 |
+
for pattern in [r"```json\s*\n?(.*?)\n?```", r"```\s*\n?(.*?)\n?```", r"\{[\s\S]*\}"]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
match = re.search(pattern, text, re.DOTALL)
|
| 295 |
if match:
|
| 296 |
try:
|
|
|
|
| 303 |
|
| 304 |
|
| 305 |
def validate_agent_output(output: Dict[str, Any], required_fields: List[str]) -> Tuple[bool, List[str]]:
    """
    Check that an agent output dict contains every required field.

    Args:
        output: Agent output dict to validate.
        required_fields: Field names that must be present as keys.

    Returns:
        Tuple of (is_valid, missing_fields); missing_fields preserves the
        order of required_fields.
    """
    missing: List[str] = []
    for field_name in required_fields:
        if field_name not in output:
            missing.append(field_name)
    return not missing, missing
|
| 309 |
|
| 310 |
|
| 311 |
+
# --- Name normalization ---
|
|
|
|
|
|
|
| 312 |
|
| 313 |
def normalize_antibiotic_name(name: str) -> str:
|
| 314 |
+
"""Map common abbreviations and brand names to standard antibiotic names."""
|
|
|
|
|
|
|
|
|
|
| 315 |
mappings = {
|
| 316 |
"amox": "amoxicillin",
|
| 317 |
"amox/clav": "amoxicillin-clavulanate",
|
|
|
|
| 334 |
"cefepime": "cefepime",
|
| 335 |
"maxipime": "cefepime",
|
| 336 |
}
|
| 337 |
+
return mappings.get(name.lower().strip(), name.lower().strip())
|
|
|
|
|
|
|
| 338 |
|
| 339 |
|
| 340 |
def normalize_organism_name(name: str) -> str:
|
| 341 |
+
"""Map common abbreviations to full organism names."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 342 |
abbreviations = {
|
| 343 |
"e. coli": "Escherichia coli",
|
| 344 |
"e.coli": "Escherichia coli",
|
|
|
|
| 353 |
"enterococcus": "Enterococcus species",
|
| 354 |
"vre": "Enterococcus (VRE)",
|
| 355 |
}
|
| 356 |
+
return abbreviations.get(name.strip().lower(), name.strip())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|