ghitaben commited on
Commit
ad9e267
·
1 Parent(s): 4c3f028

Refactor prompt templates and RAG module

Browse files
Files changed (11) hide show
  1. notebooks/kaggle_medic_demo.ipynb +6 -54
  2. src/agents.py +46 -217
  3. src/config.py +15 -71
  4. src/db/import_data.py +113 -164
  5. src/graph.py +36 -205
  6. src/loader.py +20 -92
  7. src/prompts.py +7 -43
  8. src/rag.py +68 -295
  9. src/state.py +15 -47
  10. src/tools/rag_tools.py +0 -1
  11. src/utils.py +67 -216
notebooks/kaggle_medic_demo.ipynb CHANGED
@@ -72,15 +72,7 @@
72
  "id": "205d4ba2",
73
  "metadata": {},
74
  "outputs": [],
75
- "source": [
76
- "%%capture\n",
77
- "!pip install -q \\\n",
78
- " \"langgraph>=0.0.15\" \"langchain>=0.3.0\" langchain-text-splitters langchain-community \\\n",
79
- " \"chromadb>=0.4.0\" sentence-transformers \\\n",
80
- " \"transformers>=4.50.0\" accelerate bitsandbytes \\\n",
81
- " streamlit huggingface_hub \\\n",
82
- " \"pydantic>=2.0\" python-dotenv openpyxl pypdf \"pandas>=2.0\" jq"
83
- ]
84
  },
85
  {
86
  "cell_type": "markdown",
@@ -267,9 +259,7 @@
267
  "cell_type": "markdown",
268
  "id": "37d17f6b",
269
  "metadata": {},
270
- "source": [
271
- "## 5 · Launch the App"
272
- ]
273
  },
274
  {
275
  "cell_type": "code",
@@ -277,10 +267,7 @@
277
  "id": "96ff2d63",
278
  "metadata": {},
279
  "outputs": [],
280
- "source": [
281
- "%%capture\n",
282
- "!pip install -q localtunnel"
283
- ]
284
  },
285
  {
286
  "cell_type": "code",
@@ -288,28 +275,7 @@
288
  "id": "ea6b1788",
289
  "metadata": {},
290
  "outputs": [],
291
- "source": [
292
- "import subprocess, time, requests\n",
293
- "\n",
294
- "streamlit_proc = subprocess.Popen(\n",
295
- " [\"streamlit\", \"run\", \"/kaggle/working/Med-I-C/app.py\",\n",
296
- " \"--server.port\", \"8501\",\n",
297
- " \"--server.headless\", \"true\",\n",
298
- " \"--server.enableCORS\", \"false\"],\n",
299
- " stdout=subprocess.DEVNULL,\n",
300
- " stderr=subprocess.DEVNULL,\n",
301
- ")\n",
302
- "\n",
303
- "for _ in range(15):\n",
304
- " try:\n",
305
- " if requests.get(\"http://localhost:8501\", timeout=2).status_code == 200:\n",
306
- " print(\"Streamlit running on :8501\")\n",
307
- " break\n",
308
- " except Exception:\n",
309
- " time.sleep(2)\n",
310
- "else:\n",
311
- " print(\"Streamlit may still be starting…\")"
312
- ]
313
  },
314
  {
315
  "cell_type": "code",
@@ -317,21 +283,7 @@
317
  "id": "00ecfb17",
318
  "metadata": {},
319
  "outputs": [],
320
- "source": [
321
- "tunnel_proc = subprocess.Popen(\n",
322
- " [\"npx\", \"localtunnel\", \"--port\", \"8501\"],\n",
323
- " stdout=subprocess.PIPE,\n",
324
- " stderr=subprocess.DEVNULL,\n",
325
- " text=True,\n",
326
- ")\n",
327
- "\n",
328
- "for line in tunnel_proc.stdout:\n",
329
- " if \"https://\" in line:\n",
330
- " print(\"\\n\" + \"=\"*50)\n",
331
- " print(f\" App URL: {line.strip()}\")\n",
332
- " print(\"=\"*50)\n",
333
- " break"
334
- ]
335
  }
336
  ],
337
  "metadata": {
@@ -347,4 +299,4 @@
347
  },
348
  "nbformat": 4,
349
  "nbformat_minor": 5
350
- }
 
72
  "id": "205d4ba2",
73
  "metadata": {},
74
  "outputs": [],
75
+ "source": "%%capture\n!pip install -q \\\n \"langgraph>=0.0.15\" \"langchain>=0.3.0\" langchain-text-splitters langchain-community \\\n \"chromadb>=0.4.0\" sentence-transformers \\\n \"transformers>=4.50.0\" accelerate bitsandbytes \\\n gradio huggingface_hub \\\n \"pydantic>=2.0\" python-dotenv openpyxl pypdf \"pandas>=2.0\" jq"
 
 
 
 
 
 
 
 
76
  },
77
  {
78
  "cell_type": "markdown",
 
259
  "cell_type": "markdown",
260
  "id": "37d17f6b",
261
  "metadata": {},
262
+ "source": "## 5 · Launch the Gradio App\n\nTwo tabbed scenarios are exposed in a single Gradio interface:\n\n| Tab | Scenario | Agents active |\n|---|---|---|\n| **Stage 1 — Empirical** | No lab results yet | Agent 1 (MedGemma 4B) → Agent 4 (MedGemma 4B + TxGemma 2B) |\n| **Stage 2 — Targeted** | Culture & sensitivity available | Agent 1 (MedGemma 4B) → Agent 2 (MedGemma 4B) → Agent 3 (MedGemma 4B→27B sub) → Agent 4 (MedGemma 4B + TxGemma 2B) |\n\n`demo.launch(share=True)` prints a public Gradio URL — no extra tunnel needed."
 
 
263
  },
264
  {
265
  "cell_type": "code",
 
267
  "id": "96ff2d63",
268
  "metadata": {},
269
  "outputs": [],
270
+ "source": "import json\nimport sys\n\nsys.path.insert(0, \"/kaggle/working/Med-I-C\")\n\n# ── Demo fallback (used when pipeline errors out before models are warm) ──────\n\ndef _demo_result(patient_data: dict, labs_text) -> dict:\n result = {\n \"stage\": \"targeted\" if labs_text else \"empirical\",\n \"creatinine_clearance_ml_min\": 58.3,\n \"intake_notes\": json.dumps({\n \"patient_summary\": (\n f\"{patient_data.get('age_years')}-year-old {patient_data.get('sex')} · \"\n f\"{patient_data.get('suspected_source', 'infection')}\"\n ),\n \"creatinine_clearance_ml_min\": 58.3,\n \"renal_dose_adjustment_needed\": True,\n \"identified_risk_factors\": patient_data.get(\"comorbidities\", []),\n \"infection_severity\": \"moderate\",\n \"recommended_stage\": \"targeted\" if labs_text else \"empirical\",\n }),\n \"recommendation\": {\n \"primary_antibiotic\": \"Ciprofloxacin\",\n \"dose\": \"500 mg\",\n \"route\": \"Oral\",\n \"frequency\": \"Every 12 hours\",\n \"duration\": \"7 days\",\n \"backup_antibiotic\": \"Nitrofurantoin 100 mg MR BD × 5 days\",\n \"rationale\": (\n \"Community-acquired UTI with moderate renal impairment (CrCl 58 mL/min). \"\n \"Ciprofloxacin provides broad Gram-negative coverage. No dose adjustment \"\n \"required above CrCl 30 mL/min.\"\n ),\n \"references\": [\"IDSA UTI Guidelines 2024\", \"EUCAST Breakpoint Tables v16.0\"],\n },\n \"safety_warnings\": [],\n \"errors\": [],\n }\n if labs_text:\n result[\"vision_notes\"] = json.dumps({\n \"specimen_type\": \"urine\",\n \"identified_organisms\": [{\"organism_name\": \"Escherichia coli\", \"significance\": \"pathogen\"}],\n \"susceptibility_results\": [\n {\"organism\": \"E. coli\", \"antibiotic\": \"Ciprofloxacin\", \"mic_value\": 0.25, \"interpretation\": \"S\"},\n {\"organism\": \"E. coli\", \"antibiotic\": \"Nitrofurantoin\", \"mic_value\": 16, \"interpretation\": \"S\"},\n {\"organism\": \"E. 
coli\", \"antibiotic\": \"Ampicillin\", \"mic_value\": \">32\", \"interpretation\": \"R\"},\n ],\n \"extraction_confidence\": 0.95,\n })\n result[\"trend_notes\"] = json.dumps([{\n \"organism\": \"E. coli\",\n \"antibiotic\": \"Ciprofloxacin\",\n \"risk_level\": \"LOW\",\n \"recommendation\": \"No MIC creep detected — continue current therapy.\",\n }])\n return result\n\n\n# ── Output formatters ─────────────────────────────────────────────────────────\n\ndef _parse_json_field(raw):\n if not raw or raw in (\"No lab data provided\", \"No MIC data available for trend analysis\", \"\"):\n return None\n if isinstance(raw, (dict, list)):\n return raw\n try:\n return json.loads(raw)\n except Exception:\n return None\n\n\ndef format_recommendation(result: dict) -> str:\n lines = [\"## ℞ Recommendation\\n\"]\n rec = result.get(\"recommendation\", {})\n if rec:\n drug = rec.get(\"primary_antibiotic\", \"—\")\n dose = rec.get(\"dose\", \"—\")\n route = rec.get(\"route\", \"—\")\n freq = rec.get(\"frequency\", \"—\")\n dur = rec.get(\"duration\", \"—\")\n lines.append(f\"**Drug:** {drug}\")\n lines.append(\n f\"**Dose:** {dose}  ·  **Route:** {route} \"\n f\" ·  **Frequency:** {freq}  ·  **Duration:** {dur}\"\n )\n if rec.get(\"backup_antibiotic\"):\n lines.append(f\"**Alternative:** {rec['backup_antibiotic']}\")\n if rec.get(\"rationale\"):\n lines.append(f\"\\n**Clinical rationale:** {rec['rationale']}\")\n if rec.get(\"references\"):\n lines.append(\"\\n**References:**\")\n for ref in rec[\"references\"]:\n lines.append(f\"- {ref}\")\n\n intake = _parse_json_field(result.get(\"intake_notes\", \"\"))\n if isinstance(intake, dict):\n lines.append(\"\\n---\\n## Patient Summary\")\n if intake.get(\"patient_summary\"):\n lines.append(f\"> {intake['patient_summary']}\")\n crcl = result.get(\"creatinine_clearance_ml_min\") or intake.get(\"creatinine_clearance_ml_min\")\n if crcl:\n lines.append(f\"**CrCl:** {float(crcl):.1f} mL/min\")\n if 
intake.get(\"renal_dose_adjustment_needed\"):\n lines.append(\"⚠ **Renal dose adjustment required**\")\n factors = intake.get(\"identified_risk_factors\", [])\n if factors:\n lines.append(f\"**Risk factors:** {', '.join(factors)}\")\n\n warnings = result.get(\"safety_warnings\", [])\n if warnings:\n lines.append(\"\\n---\\n## ⚠ Safety Warnings\")\n for w in warnings:\n lines.append(f\"- {w}\")\n\n errors = result.get(\"errors\", [])\n if errors:\n lines.append(\"\\n---\\n## Errors\")\n for e in errors:\n lines.append(f\"- {e}\")\n\n return \"\\n\".join(lines)\n\n\ndef format_lab_analysis(result: dict) -> str:\n lines = []\n vision = _parse_json_field(result.get(\"vision_notes\", \"\"))\n trend = _parse_json_field(result.get(\"trend_notes\", \"\"))\n\n if vision is None:\n return \"*No lab data processed.*\"\n\n if isinstance(vision, dict):\n lines.append(\"## Lab Extraction\")\n if vision.get(\"specimen_type\"):\n lines.append(f\"**Specimen:** {vision['specimen_type'].capitalize()}\")\n if vision.get(\"extraction_confidence\") is not None:\n conf = float(vision[\"extraction_confidence\"])\n lines.append(f\"**Extraction confidence:** {conf:.0%}\")\n\n orgs = vision.get(\"identified_organisms\", [])\n if orgs:\n lines.append(\"\\n**Identified organisms:**\")\n for o in orgs:\n name = o.get(\"organism_name\", \"Unknown\")\n sig = o.get(\"significance\", \"\")\n lines.append(f\"- **{name}**\" + (f\" — {sig}\" if sig else \"\"))\n\n sus = vision.get(\"susceptibility_results\", [])\n if sus:\n lines.append(\"\\n**Susceptibility results:**\")\n lines.append(\"| Organism | Antibiotic | MIC (mg/L) | Result |\")\n lines.append(\"|---|---|---|---|\")\n for s in sus:\n interp = s.get(\"interpretation\", \"\")\n icon = {\"S\": \"✓ S\", \"R\": \"✗ R\", \"I\": \"~ I\"}.get(interp.upper(), interp)\n lines.append(\n f\"| {s.get('organism','')} | {s.get('antibiotic','')} \"\n f\"| {s.get('mic_value','')} | {icon} |\"\n )\n\n if trend:\n items = trend if isinstance(trend, list) else 
[trend]\n lines.append(\"\\n## MIC Trend Analysis\")\n for item in items:\n if not isinstance(item, dict):\n lines.append(str(item))\n continue\n risk = item.get(\"risk_level\", \"UNKNOWN\").upper()\n icon = {\"HIGH\": \"🚨\", \"MODERATE\": \"⚠\"}.get(risk, \"✓\")\n org = item.get(\"organism\", \"\")\n ab = item.get(\"antibiotic\", \"\")\n label = f\"{org} / {ab} — \" if (org or ab) else \"\"\n lines.append(f\"**{icon} {label}{risk}** — {item.get('recommendation', '')}\")\n\n return \"\\n\".join(lines) if lines else \"*No lab analysis available.*\"\n\n\n# ── Pipeline runner helpers ───────────────────────────────────────────────────\n\ndef _build_patient_data(age, weight, height, sex, creatinine,\n infection_site, suspected_source,\n medications_str, allergies_str, comorbidities_str):\n return {\n \"age_years\": float(age),\n \"weight_kg\": float(weight),\n \"height_cm\": float(height),\n \"sex\": sex,\n \"serum_creatinine_mg_dl\": float(creatinine),\n \"infection_site\": infection_site,\n \"suspected_source\": suspected_source or f\"{infection_site} infection\",\n \"medications\": [m.strip() for m in medications_str.split(\"\\n\") if m.strip()],\n \"allergies\": [a.strip() for a in allergies_str.split(\"\\n\") if a.strip()],\n \"comorbidities\":[c.strip() for c in comorbidities_str.split(\"\\n\") if c.strip()],\n }\n\n\ndef run_empirical_scenario(age, weight, height, sex, creatinine,\n infection_site, suspected_source,\n medications_str, allergies_str, comorbidities_str):\n \"\"\"Stage 1 — Empirical: no lab results.\n Active models: MedGemma 4B (Agent 1) → MedGemma 4B + TxGemma 2B (Agent 4).\n \"\"\"\n patient_data = _build_patient_data(\n age, weight, height, sex, creatinine,\n infection_site, suspected_source,\n medications_str, allergies_str, comorbidities_str,\n )\n try:\n from src.graph import run_pipeline\n result = run_pipeline(patient_data, labs_raw_text=None)\n except Exception as exc:\n result = _demo_result(patient_data, None)\n 
result[\"errors\"].append(f\"[Demo mode — pipeline error: {exc}]\")\n return format_recommendation(result)\n\n\ndef run_targeted_scenario(age, weight, height, sex, creatinine,\n infection_site, suspected_source,\n medications_str, allergies_str, comorbidities_str,\n labs_text):\n \"\"\"Stage 2 — Targeted: lab culture & sensitivity available.\n Active models: MedGemma 4B (Agents 1, 2) → MedGemma 4B→27B sub (Agent 3)\n → MedGemma 4B + TxGemma 2B (Agent 4).\n \"\"\"\n patient_data = _build_patient_data(\n age, weight, height, sex, creatinine,\n infection_site, suspected_source,\n medications_str, allergies_str, comorbidities_str,\n )\n labs = labs_text.strip() if labs_text else None\n try:\n from src.graph import run_pipeline\n result = run_pipeline(patient_data, labs_raw_text=labs)\n except Exception as exc:\n result = _demo_result(patient_data, labs)\n result[\"errors\"].append(f\"[Demo mode — pipeline error: {exc}]\")\n return format_recommendation(result), format_lab_analysis(result)\n\n\nprint(\"Helper functions loaded.\")"
 
 
 
271
  },
272
  {
273
  "cell_type": "code",
 
275
  "id": "ea6b1788",
276
  "metadata": {},
277
  "outputs": [],
278
+ "source": "import gradio as gr\n\nINFECTION_SITES = [\"urinary\", \"respiratory\", \"bloodstream\", \"skin\", \"intra-abdominal\", \"CNS\", \"other\"]\n\n\ndef _patient_inputs():\n \"\"\"Create patient-demographics input widgets inside the current gr.Blocks context.\"\"\"\n with gr.Row():\n age = gr.Number(label=\"Age (years)\", value=65, minimum=0, maximum=120, precision=0)\n weight = gr.Number(label=\"Weight (kg)\", value=70.0, minimum=1, maximum=300)\n height = gr.Number(label=\"Height (cm)\", value=170.0,minimum=50, maximum=250)\n with gr.Row():\n sex = gr.Dropdown(label=\"Biological sex\", choices=[\"male\", \"female\"], value=\"male\")\n creatinine = gr.Number(label=\"Serum Creatinine (mg/dL)\", value=1.2, minimum=0.1, maximum=20.0)\n infection_site = gr.Dropdown(label=\"Infection site\", choices=INFECTION_SITES, value=\"urinary\")\n suspected_source = gr.Textbox(label=\"Suspected source\",\n placeholder=\"e.g., community-acquired UTI\")\n with gr.Row():\n medications = gr.Textbox(label=\"Current medications (one per line)\",\n placeholder=\"Metformin\\nLisinopril\", lines=3)\n allergies = gr.Textbox(label=\"Drug allergies (one per line)\",\n placeholder=\"Penicillin\\nSulfa\", lines=3)\n comorbidities = gr.Textbox(label=\"Comorbidities / MDR risk factors (one per line)\",\n placeholder=\"Diabetes\\nCKD\\nPrior MRSA\", lines=3)\n return [age, weight, height, sex, creatinine, infection_site,\n suspected_source, medications, allergies, comorbidities]\n\n\nwith gr.Blocks(title=\"AMR-Guard · Med-I-C\", theme=gr.themes.Soft()) as demo:\n\n gr.Markdown(\"\"\"\n# ⚕ AMR-Guard — Infection Lifecycle Orchestrator\n\n**Multi-Agent Clinical Decision Support for Antimicrobial Stewardship**\n\n| Model | Agent(s) | Role |\n|---|---|---|\n| `google/medgemma-4b-it` | 1, 2, 4 | Intake · Lab extraction · Final Rx |\n| `google/medgemma-4b-it` (27B sub on T4) | 3 | MIC trend analysis |\n| `google/txgemma-2b-predict` (9B sub on T4) | 4 (safety) | Drug interaction screening |\n\n> 
⚠ **Research demo only** — not validated for clinical use. All output must be reviewed by a licensed clinician.\n---\n\"\"\")\n\n with gr.Tabs():\n\n # ──────────────────────────────────────────────────────────────────────\n # TAB 1 — Stage 1: Empirical (no lab results)\n # ──────────────────────────────────────────────────────────────────────\n with gr.Tab(\"Stage 1 — Empirical (no lab results)\"):\n gr.Markdown(\"\"\"\n**Scenario:** Patient presents without culture / sensitivity data.\n\n**Pipeline:** Agent 1 — *Intake Historian* (MedGemma 4B IT) → Agent 4 — *Clinical Pharmacologist* (MedGemma 4B IT + TxGemma 2B)\n\"\"\")\n emp_inputs = _patient_inputs()\n emp_btn = gr.Button(\"Run Empirical Pipeline\", variant=\"primary\")\n emp_output = gr.Markdown(label=\"Recommendation\")\n\n emp_btn.click(\n fn=run_empirical_scenario,\n inputs=emp_inputs,\n outputs=emp_output,\n )\n\n # ──────────────────────────────────────────────────────────────────────\n # TAB 2 — Stage 2: Targeted (culture & sensitivity available)\n # ──────────────────────────────────────────────────────────────────────\n with gr.Tab(\"Stage 2 — Targeted (lab results available)\"):\n gr.Markdown(\"\"\"\n**Scenario:** Culture & sensitivity report (any language) is available.\n\n**Pipeline:** Agent 1 (MedGemma 4B IT) → Agent 2 — *Vision Specialist* (MedGemma 4B IT) → Agent 3 — *Trend Analyst* (MedGemma 27B→4B sub) → Agent 4 (MedGemma 4B IT + TxGemma 2B)\n\"\"\")\n tgt_inputs = _patient_inputs()\n tgt_labs = gr.Textbox(\n label=\"Lab / Culture Report — paste text (any language)\",\n placeholder=(\n \"Organism: Escherichia coli\\n\"\n \"Ciprofloxacin: S MIC 0.25 mg/L\\n\"\n \"Nitrofurantoin: S MIC 16 mg/L\\n\"\n \"Ampicillin: R MIC >32 mg/L\"\n ),\n lines=6,\n )\n tgt_btn = gr.Button(\"Run Targeted Pipeline\", variant=\"primary\")\n\n with gr.Row():\n tgt_rec_output = gr.Markdown(label=\"Recommendation\")\n tgt_lab_output = gr.Markdown(label=\"Lab Analysis & MIC Trend\")\n\n tgt_btn.click(\n 
fn=run_targeted_scenario,\n inputs=tgt_inputs + [tgt_labs],\n outputs=[tgt_rec_output, tgt_lab_output],\n )\n\n gr.Markdown(\"\"\"\n---\n**Knowledge bases:** EUCAST v16.0 · WHO AWaRe 2024 · IDSA AMR Guidance 2024 · ATLAS Surveillance · WHO GLASS · DDInter 2.0 \n**Inference:** HuggingFace Transformers · 4-bit quantization · Kaggle T4 GPU\n\"\"\")\n\nprint(\"Gradio app defined. Run the next cell to launch.\")"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
  },
280
  {
281
  "cell_type": "code",
 
283
  "id": "00ecfb17",
284
  "metadata": {},
285
  "outputs": [],
286
+ "source": "# share=True creates a public Gradio URL (works out-of-the-box on Kaggle — no localtunnel needed).\n# The URL is printed below and stays live for ~72 hours.\ndemo.launch(share=True, quiet=True)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  }
288
  ],
289
  "metadata": {
 
299
  },
300
  "nbformat": 4,
301
  "nbformat_minor": 5
302
+ }
src/agents.py CHANGED
@@ -1,18 +1,15 @@
1
  """
2
- Multi-Agent System.
3
 
4
- Implements the 4 specialized agents for the infection lifecycle workflow:
5
- - Agent 1: Intake Historian - Parse patient data, risk factors, calculate CrCl
6
- - Agent 2: Vision Specialist - Extract structured data from lab reports
7
- - Agent 3: Trend Analyst - Detect MIC creep and resistance velocity
8
- - Agent 4: Clinical Pharmacologist - Final Rx recommendations + safety checks
9
  """
10
 
11
- from __future__ import annotations
12
-
13
  import json
14
  import logging
15
- from typing import Any, Dict, Optional
16
 
17
  from .config import get_settings
18
  from .loader import run_inference, TextModelName
@@ -40,36 +37,12 @@ from .utils import (
40
  logger = logging.getLogger(__name__)
41
 
42
 
43
- # =============================================================================
44
- # AGENT 1: INTAKE HISTORIAN
45
- # =============================================================================
46
-
47
  def run_intake_historian(state: InfectionState) -> InfectionState:
48
- """
49
- Agent 1: Parse patient data, calculate CrCl, identify risk factors.
50
-
51
- Input state fields used:
52
- - age_years, weight_kg, height_cm, sex
53
- - serum_creatinine_mg_dl
54
- - medications, allergies, comorbidities
55
- - suspected_source, infection_site
56
-
57
- Output state fields updated:
58
- - creatinine_clearance_ml_min
59
- - intake_notes
60
- - stage (empirical/targeted)
61
- - route_to_vision
62
- """
63
  logger.info("Running Intake Historian agent...")
64
 
65
- # Calculate CrCl if we have the required data
66
  crcl = None
67
- if all([
68
- state.get("age_years"),
69
- state.get("weight_kg"),
70
- state.get("serum_creatinine_mg_dl"),
71
- state.get("sex"),
72
- ]):
73
  try:
74
  crcl = calculate_crcl(
75
  age_years=state["age_years"],
@@ -85,20 +58,14 @@ def run_intake_historian(state: InfectionState) -> InfectionState:
85
  logger.warning(f"Could not calculate CrCl: {e}")
86
  state.setdefault("errors", []).append(f"CrCl calculation error: {e}")
87
 
88
- # Build patient data string for prompt
89
  patient_data = _format_patient_data(state)
90
-
91
- # Get RAG context
92
  query = f"treatment {state.get('suspected_source', '')} {state.get('infection_site', '')}"
93
  rag_context = get_context_for_agent(
94
  agent_name="intake_historian",
95
  query=query,
96
- patient_context={
97
- "pathogen_type": state.get("suspected_source"),
98
- },
99
  )
100
 
101
- # Format the prompt
102
  prompt = f"{INTAKE_HISTORIAN_SYSTEM}\n\n{INTAKE_HISTORIAN_PROMPT.format(
103
  patient_data=patient_data,
104
  medications=', '.join(state.get('medications', [])) or 'None reported',
@@ -108,34 +75,20 @@ def run_intake_historian(state: InfectionState) -> InfectionState:
108
  rag_context=rag_context,
109
  )}"
110
 
111
- # Run inference
112
  try:
113
- response = run_inference(
114
- prompt=prompt,
115
- model_name="medgemma_4b",
116
- max_new_tokens=1024,
117
- temperature=0.2,
118
- )
119
-
120
- # Parse response
121
  parsed = safe_json_parse(response)
122
  if parsed:
123
  state["intake_notes"] = json.dumps(parsed, indent=2)
124
-
125
- # Update state from parsed response
126
  if parsed.get("creatinine_clearance_ml_min") and crcl is None:
127
  state["creatinine_clearance_ml_min"] = parsed["creatinine_clearance_ml_min"]
128
-
129
- # Determine stage
130
- recommended_stage = parsed.get("recommended_stage", "empirical")
131
- state["stage"] = recommended_stage
132
-
133
- # Route to vision if we have lab data to process
134
- state["route_to_vision"] = bool(state.get("labs_raw_text"))
135
  else:
136
  state["intake_notes"] = response
137
  state["stage"] = "empirical"
138
- state["route_to_vision"] = bool(state.get("labs_raw_text"))
 
 
139
 
140
  except Exception as e:
141
  logger.error(f"Intake Historian error: {e}")
@@ -147,23 +100,8 @@ def run_intake_historian(state: InfectionState) -> InfectionState:
147
  return state
148
 
149
 
150
- # =============================================================================
151
- # AGENT 2: VISION SPECIALIST
152
- # =============================================================================
153
-
154
  def run_vision_specialist(state: InfectionState) -> InfectionState:
155
- """
156
- Agent 2: Extract structured data from lab reports (text, images, PDFs).
157
-
158
- Input state fields used:
159
- - labs_raw_text (extracted text from lab report)
160
-
161
- Output state fields updated:
162
- - labs_parsed
163
- - mic_data
164
- - vision_notes
165
- - route_to_trend_analyst
166
- """
167
  logger.info("Running Vision Specialist agent...")
168
 
169
  labs_raw = state.get("labs_raw_text", "")
@@ -173,68 +111,54 @@ def run_vision_specialist(state: InfectionState) -> InfectionState:
173
  state["route_to_trend_analyst"] = False
174
  return state
175
 
176
- # Detect language (simplified - in production would use langdetect)
177
  language = "English (assumed)"
178
-
179
- # Get RAG context for lab interpretation
180
  rag_context = get_context_for_agent(
181
  agent_name="vision_specialist",
182
  query="culture sensitivity susceptibility interpretation",
183
  patient_context={},
184
  )
185
 
186
- # Format the prompt
187
  prompt = f"{VISION_SPECIALIST_SYSTEM}\n\n{VISION_SPECIALIST_PROMPT.format(
188
  report_content=labs_raw,
189
  source_format='text',
190
  language=language,
191
  )}"
192
 
193
- # Run inference
194
  try:
195
- response = run_inference(
196
- prompt=prompt,
197
- model_name="medgemma_4b",
198
- max_new_tokens=2048,
199
- temperature=0.1,
200
- )
201
-
202
- # Parse response
203
  parsed = safe_json_parse(response)
204
  if parsed:
205
  state["vision_notes"] = json.dumps(parsed, indent=2)
206
 
207
- # Extract organisms and susceptibility data
208
  organisms = parsed.get("identified_organisms", [])
209
  susceptibility = parsed.get("susceptibility_results", [])
210
 
211
- # Convert to MICDatum format
212
- mic_data = []
213
- for result in susceptibility:
214
- mic_datum = {
215
- "organism": normalize_organism_name(result.get("organism", "")),
216
- "antibiotic": normalize_antibiotic_name(result.get("antibiotic", "")),
217
- "mic_value": str(result.get("mic_value", "")),
218
- "mic_unit": result.get("mic_unit", "mg/L"),
219
- "interpretation": result.get("interpretation"),
220
  }
221
- mic_data.append(mic_datum)
 
222
 
223
  state["mic_data"] = mic_data
224
- state["labs_parsed"] = [{
225
- "name": org.get("organism_name", "Unknown"),
226
- "value": org.get("colony_count", ""),
227
- "flag": "pathogen" if org.get("significance") == "pathogen" else None,
228
- } for org in organisms]
229
-
230
- # Route to trend analyst if we have MIC data
 
231
  state["route_to_trend_analyst"] = len(mic_data) > 0
232
 
233
- # Check for critical findings
234
  critical = parsed.get("critical_findings", [])
235
  if critical:
236
  state.setdefault("safety_warnings", []).extend(critical)
237
-
238
  else:
239
  state["vision_notes"] = response
240
  state["route_to_trend_analyst"] = False
@@ -249,23 +173,8 @@ def run_vision_specialist(state: InfectionState) -> InfectionState:
249
  return state
250
 
251
 
252
- # =============================================================================
253
- # AGENT 3: TREND ANALYST
254
- # =============================================================================
255
-
256
  def run_trend_analyst(state: InfectionState) -> InfectionState:
257
- """
258
- Agent 3: Analyze MIC trends and detect resistance velocity.
259
-
260
- Input state fields used:
261
- - mic_data (current MIC readings)
262
- - Historical MIC data (if available)
263
-
264
- Output state fields updated:
265
- - mic_trend_summary
266
- - trend_notes
267
- - safety_warnings (if high risk detected)
268
- """
269
  logger.info("Running Trend Analyst agent...")
270
 
271
  mic_data = state.get("mic_data", [])
@@ -274,14 +183,12 @@ def run_trend_analyst(state: InfectionState) -> InfectionState:
274
  state["trend_notes"] = "No MIC data available for trend analysis"
275
  return state
276
 
277
- # For each organism-antibiotic pair, analyze trends
278
  trend_results = []
279
 
280
  for mic in mic_data:
281
  organism = mic.get("organism", "Unknown")
282
  antibiotic = mic.get("antibiotic", "Unknown")
283
 
284
- # Get RAG context for breakpoints
285
  rag_context = get_context_for_agent(
286
  agent_name="trend_analyst",
287
  query=f"breakpoint {organism} {antibiotic}",
@@ -292,10 +199,9 @@ def run_trend_analyst(state: InfectionState) -> InfectionState:
292
  },
293
  )
294
 
295
- # Format MIC history (in production, would pull from database)
296
  mic_history = [{"date": "current", "mic_value": mic.get("mic_value", "0")}]
297
 
298
- # Format prompt
299
  prompt = f"{TREND_ANALYST_SYSTEM}\n\n{TREND_ANALYST_PROMPT.format(
300
  organism=organism,
301
  antibiotic=antibiotic,
@@ -305,18 +211,16 @@ def run_trend_analyst(state: InfectionState) -> InfectionState:
305
  )}"
306
 
307
  try:
 
308
  response = run_inference(
309
  prompt=prompt,
310
- model_name="medgemma_27b", # Agent 3: MedGemma 27B per PLAN.md (env maps to 4B on limited GPU)
311
  max_new_tokens=1024,
312
  temperature=0.2,
313
  )
314
-
315
  parsed = safe_json_parse(response)
316
  if parsed:
317
  trend_results.append(parsed)
318
-
319
- # Add safety warning if high/critical risk
320
  risk_level = parsed.get("risk_level", "LOW")
321
  if risk_level in ["HIGH", "CRITICAL"]:
322
  warning = f"MIC trend alert for {organism}/{antibiotic}: {parsed.get('recommendation', 'Review needed')}"
@@ -328,10 +232,8 @@ def run_trend_analyst(state: InfectionState) -> InfectionState:
328
  logger.error(f"Trend analysis error for {organism}/{antibiotic}: {e}")
329
  trend_results.append({"error": str(e)})
330
 
331
- # Summarize trends
332
  state["trend_notes"] = json.dumps(trend_results, indent=2)
333
 
334
- # Create summary
335
  high_risk_count = sum(1 for t in trend_results if t.get("risk_level") in ["HIGH", "CRITICAL"])
336
  state["mic_trend_summary"] = f"Analyzed {len(trend_results)} organism-antibiotic pairs. High-risk findings: {high_risk_count}"
337
 
@@ -339,43 +241,21 @@ def run_trend_analyst(state: InfectionState) -> InfectionState:
339
  return state
340
 
341
 
342
- # =============================================================================
343
- # AGENT 4: CLINICAL PHARMACOLOGIST
344
- # =============================================================================
345
-
346
  def run_clinical_pharmacologist(state: InfectionState) -> InfectionState:
347
- """
348
- Agent 4: Generate final antibiotic recommendation with safety checks.
349
-
350
- Input state fields used:
351
- - intake_notes, vision_notes, trend_notes
352
- - age_years, weight_kg, creatinine_clearance_ml_min
353
- - allergies, medications
354
- - infection_site, suspected_source
355
-
356
- Output state fields updated:
357
- - recommendation
358
- - pharmacology_notes
359
- - safety_warnings (additional alerts)
360
- """
361
  logger.info("Running Clinical Pharmacologist agent...")
362
 
363
- # Gather all previous agent outputs
364
  intake_summary = state.get("intake_notes", "No intake data")
365
  lab_results = state.get("vision_notes", "No lab data")
366
  trend_analysis = state.get("trend_notes", "No trend data")
367
 
368
- # Get RAG context
369
  query = f"treatment {state.get('suspected_source', '')} antibiotic recommendation"
370
  rag_context = get_context_for_agent(
371
  agent_name="clinical_pharmacologist",
372
  query=query,
373
- patient_context={
374
- "proposed_antibiotic": None, # Will be determined by agent
375
- },
376
  )
377
 
378
- # Format prompt
379
  prompt = f"{CLINICAL_PHARMACOLOGIST_SYSTEM}\n\n{CLINICAL_PHARMACOLOGIST_PROMPT.format(
380
  intake_summary=intake_summary,
381
  lab_results=lab_results,
@@ -392,18 +272,11 @@ def run_clinical_pharmacologist(state: InfectionState) -> InfectionState:
392
  )}"
393
 
394
  try:
395
- response = run_inference(
396
- prompt=prompt,
397
- model_name="medgemma_4b",
398
- max_new_tokens=2048,
399
- temperature=0.2,
400
- )
401
-
402
  parsed = safe_json_parse(response)
403
  if parsed:
404
  state["pharmacology_notes"] = json.dumps(parsed, indent=2)
405
 
406
- # Build recommendation
407
  primary = parsed.get("primary_recommendation", {})
408
  recommendation = {
409
  "primary_antibiotic": primary.get("antibiotic"),
@@ -416,19 +289,16 @@ def run_clinical_pharmacologist(state: InfectionState) -> InfectionState:
416
  "safety_alerts": [a.get("message") for a in parsed.get("safety_alerts", [])],
417
  }
418
 
419
- # Add alternative if provided
420
  alt = parsed.get("alternative_recommendation", {})
421
  if alt.get("antibiotic"):
422
  recommendation["backup_antibiotic"] = alt.get("antibiotic")
423
 
424
  state["recommendation"] = recommendation
425
 
426
- # Add safety alerts to state
427
  for alert in parsed.get("safety_alerts", []):
428
  if alert.get("level") in ["WARNING", "CRITICAL"]:
429
  state.setdefault("safety_warnings", []).append(alert.get("message"))
430
 
431
- # Run TxGemma safety check (optional)
432
  if primary.get("antibiotic"):
433
  safety_result = _run_txgemma_safety_check(
434
  antibiotic=primary.get("antibiotic"),
@@ -441,7 +311,6 @@ def run_clinical_pharmacologist(state: InfectionState) -> InfectionState:
441
  )
442
  if safety_result:
443
  state.setdefault("debug_log", []).append(f"TxGemma safety: {safety_result}")
444
-
445
  else:
446
  state["pharmacology_notes"] = response
447
  state["recommendation"] = {"rationale": response}
@@ -455,12 +324,8 @@ def run_clinical_pharmacologist(state: InfectionState) -> InfectionState:
455
  return state
456
 
457
 
458
- # =============================================================================
459
- # HELPER FUNCTIONS
460
- # =============================================================================
461
-
462
  def _format_patient_data(state: InfectionState) -> str:
463
- """Format patient data for prompt injection."""
464
  lines = []
465
 
466
  if state.get("patient_id"):
@@ -505,11 +370,7 @@ def _run_txgemma_safety_check(
505
  crcl: Optional[float],
506
  medications: list,
507
  ) -> Optional[str]:
508
- """
509
- Run TxGemma safety check (supplementary).
510
-
511
- TxGemma is used only for safety validation, not primary recommendations.
512
- """
513
  try:
514
  prompt = TXGEMMA_SAFETY_PROMPT.format(
515
  antibiotic=antibiotic,
@@ -520,25 +381,13 @@ def _run_txgemma_safety_check(
520
  crcl=crcl or "Unknown",
521
  medications=", ".join(medications) if medications else "None",
522
  )
523
-
524
- response = run_inference(
525
- prompt=prompt,
526
- model_name="txgemma_9b", # Agent 4 safety: TxGemma 9B per PLAN.md (env maps to 2B on limited GPU)
527
- max_new_tokens=256,
528
- temperature=0.1,
529
- )
530
-
531
- return response
532
-
533
  except Exception as e:
534
  logger.warning(f"TxGemma safety check failed: {e}")
535
  return None
536
 
537
 
538
- # =============================================================================
539
- # AGENT REGISTRY
540
- # =============================================================================
541
-
542
  AGENTS = {
543
  "intake_historian": run_intake_historian,
544
  "vision_specialist": run_vision_specialist,
@@ -548,27 +397,7 @@ AGENTS = {
548
 
549
 
550
  def run_agent(agent_name: str, state: InfectionState) -> InfectionState:
551
- """
552
- Run a specific agent by name.
553
-
554
- Args:
555
- agent_name: Name of the agent to run
556
- state: Current infection state
557
-
558
- Returns:
559
- Updated infection state
560
- """
561
  if agent_name not in AGENTS:
562
  raise ValueError(f"Unknown agent: {agent_name}")
563
-
564
  return AGENTS[agent_name](state)
565
-
566
-
567
- __all__ = [
568
- "run_intake_historian",
569
- "run_vision_specialist",
570
- "run_trend_analyst",
571
- "run_clinical_pharmacologist",
572
- "run_agent",
573
- "AGENTS",
574
- ]
 
1
  """
2
+ Four-agent pipeline for the infection lifecycle workflow.
3
 
4
+ Agent 1 - Intake Historian: parse patient data, calculate CrCl, identify AMR risk factors
5
+ Agent 2 - Vision Specialist: extract organisms and MIC values from lab reports
6
+ Agent 3 - Trend Analyst: detect MIC creep and resistance velocity
7
+ Agent 4 - Clinical Pharmacologist: generate final antibiotic recommendation with safety checks
 
8
  """
9
 
 
 
10
  import json
11
  import logging
12
+ from typing import Optional
13
 
14
  from .config import get_settings
15
  from .loader import run_inference, TextModelName
 
37
  logger = logging.getLogger(__name__)
38
 
39
 
 
 
 
 
40
  def run_intake_historian(state: InfectionState) -> InfectionState:
41
+ """Parse patient data, calculate CrCl, identify MDR risk factors, and set the treatment stage."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  logger.info("Running Intake Historian agent...")
43
 
 
44
  crcl = None
45
+ if all([state.get("age_years"), state.get("weight_kg"), state.get("serum_creatinine_mg_dl"), state.get("sex")]):
 
 
 
 
 
46
  try:
47
  crcl = calculate_crcl(
48
  age_years=state["age_years"],
 
58
  logger.warning(f"Could not calculate CrCl: {e}")
59
  state.setdefault("errors", []).append(f"CrCl calculation error: {e}")
60
 
 
61
  patient_data = _format_patient_data(state)
 
 
62
  query = f"treatment {state.get('suspected_source', '')} {state.get('infection_site', '')}"
63
  rag_context = get_context_for_agent(
64
  agent_name="intake_historian",
65
  query=query,
66
+ patient_context={"pathogen_type": state.get("suspected_source")},
 
 
67
  )
68
 
 
69
  prompt = f"{INTAKE_HISTORIAN_SYSTEM}\n\n{INTAKE_HISTORIAN_PROMPT.format(
70
  patient_data=patient_data,
71
  medications=', '.join(state.get('medications', [])) or 'None reported',
 
75
  rag_context=rag_context,
76
  )}"
77
 
 
78
  try:
79
+ response = run_inference(prompt=prompt, model_name="medgemma_4b", max_new_tokens=1024, temperature=0.2)
 
 
 
 
 
 
 
80
  parsed = safe_json_parse(response)
81
  if parsed:
82
  state["intake_notes"] = json.dumps(parsed, indent=2)
 
 
83
  if parsed.get("creatinine_clearance_ml_min") and crcl is None:
84
  state["creatinine_clearance_ml_min"] = parsed["creatinine_clearance_ml_min"]
85
+ state["stage"] = parsed.get("recommended_stage", "empirical")
 
 
 
 
 
 
86
  else:
87
  state["intake_notes"] = response
88
  state["stage"] = "empirical"
89
+
90
+ # Route to vision only if lab text was provided
91
+ state["route_to_vision"] = bool(state.get("labs_raw_text"))
92
 
93
  except Exception as e:
94
  logger.error(f"Intake Historian error: {e}")
 
100
  return state
101
 
102
 
 
 
 
 
103
  def run_vision_specialist(state: InfectionState) -> InfectionState:
104
+ """Extract pathogen names, MIC values, and S/I/R interpretations from lab report text."""
 
 
 
 
 
 
 
 
 
 
 
105
  logger.info("Running Vision Specialist agent...")
106
 
107
  labs_raw = state.get("labs_raw_text", "")
 
111
  state["route_to_trend_analyst"] = False
112
  return state
113
 
114
+ # Language detection is not implemented; we assume English or instruct the model to translate
115
  language = "English (assumed)"
 
 
116
  rag_context = get_context_for_agent(
117
  agent_name="vision_specialist",
118
  query="culture sensitivity susceptibility interpretation",
119
  patient_context={},
120
  )
121
 
 
122
  prompt = f"{VISION_SPECIALIST_SYSTEM}\n\n{VISION_SPECIALIST_PROMPT.format(
123
  report_content=labs_raw,
124
  source_format='text',
125
  language=language,
126
  )}"
127
 
 
128
  try:
129
+ response = run_inference(prompt=prompt, model_name="medgemma_4b", max_new_tokens=2048, temperature=0.1)
 
 
 
 
 
 
 
130
  parsed = safe_json_parse(response)
131
  if parsed:
132
  state["vision_notes"] = json.dumps(parsed, indent=2)
133
 
 
134
  organisms = parsed.get("identified_organisms", [])
135
  susceptibility = parsed.get("susceptibility_results", [])
136
 
137
+ mic_data = [
138
+ {
139
+ "organism": normalize_organism_name(r.get("organism", "")),
140
+ "antibiotic": normalize_antibiotic_name(r.get("antibiotic", "")),
141
+ "mic_value": str(r.get("mic_value", "")),
142
+ "mic_unit": r.get("mic_unit", "mg/L"),
143
+ "interpretation": r.get("interpretation"),
 
 
144
  }
145
+ for r in susceptibility
146
+ ]
147
 
148
  state["mic_data"] = mic_data
149
+ state["labs_parsed"] = [
150
+ {
151
+ "name": org.get("organism_name", "Unknown"),
152
+ "value": org.get("colony_count", ""),
153
+ "flag": "pathogen" if org.get("significance") == "pathogen" else None,
154
+ }
155
+ for org in organisms
156
+ ]
157
  state["route_to_trend_analyst"] = len(mic_data) > 0
158
 
 
159
  critical = parsed.get("critical_findings", [])
160
  if critical:
161
  state.setdefault("safety_warnings", []).extend(critical)
 
162
  else:
163
  state["vision_notes"] = response
164
  state["route_to_trend_analyst"] = False
 
173
  return state
174
 
175
 
 
 
 
 
176
  def run_trend_analyst(state: InfectionState) -> InfectionState:
177
+ """Analyze MIC trends per organism-antibiotic pair and flag high-risk creep."""
 
 
 
 
 
 
 
 
 
 
 
178
  logger.info("Running Trend Analyst agent...")
179
 
180
  mic_data = state.get("mic_data", [])
 
183
  state["trend_notes"] = "No MIC data available for trend analysis"
184
  return state
185
 
 
186
  trend_results = []
187
 
188
  for mic in mic_data:
189
  organism = mic.get("organism", "Unknown")
190
  antibiotic = mic.get("antibiotic", "Unknown")
191
 
 
192
  rag_context = get_context_for_agent(
193
  agent_name="trend_analyst",
194
  query=f"breakpoint {organism} {antibiotic}",
 
199
  },
200
  )
201
 
202
+ # Single time-point history trend analysis requires historical data in production
203
  mic_history = [{"date": "current", "mic_value": mic.get("mic_value", "0")}]
204
 
 
205
  prompt = f"{TREND_ANALYST_SYSTEM}\n\n{TREND_ANALYST_PROMPT.format(
206
  organism=organism,
207
  antibiotic=antibiotic,
 
211
  )}"
212
 
213
  try:
214
+ # Agent 3 is designed for MedGemma 27B; on limited GPU the env var maps this to 4B
215
  response = run_inference(
216
  prompt=prompt,
217
+ model_name="medgemma_27b",
218
  max_new_tokens=1024,
219
  temperature=0.2,
220
  )
 
221
  parsed = safe_json_parse(response)
222
  if parsed:
223
  trend_results.append(parsed)
 
 
224
  risk_level = parsed.get("risk_level", "LOW")
225
  if risk_level in ["HIGH", "CRITICAL"]:
226
  warning = f"MIC trend alert for {organism}/{antibiotic}: {parsed.get('recommendation', 'Review needed')}"
 
232
  logger.error(f"Trend analysis error for {organism}/{antibiotic}: {e}")
233
  trend_results.append({"error": str(e)})
234
 
 
235
  state["trend_notes"] = json.dumps(trend_results, indent=2)
236
 
 
237
  high_risk_count = sum(1 for t in trend_results if t.get("risk_level") in ["HIGH", "CRITICAL"])
238
  state["mic_trend_summary"] = f"Analyzed {len(trend_results)} organism-antibiotic pairs. High-risk findings: {high_risk_count}"
239
 
 
241
  return state
242
 
243
 
 
 
 
 
244
  def run_clinical_pharmacologist(state: InfectionState) -> InfectionState:
245
+ """Synthesize all agent outputs into a final antibiotic recommendation with safety checks."""
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  logger.info("Running Clinical Pharmacologist agent...")
247
 
 
248
  intake_summary = state.get("intake_notes", "No intake data")
249
  lab_results = state.get("vision_notes", "No lab data")
250
  trend_analysis = state.get("trend_notes", "No trend data")
251
 
 
252
  query = f"treatment {state.get('suspected_source', '')} antibiotic recommendation"
253
  rag_context = get_context_for_agent(
254
  agent_name="clinical_pharmacologist",
255
  query=query,
256
+ patient_context={"proposed_antibiotic": None},
 
 
257
  )
258
 
 
259
  prompt = f"{CLINICAL_PHARMACOLOGIST_SYSTEM}\n\n{CLINICAL_PHARMACOLOGIST_PROMPT.format(
260
  intake_summary=intake_summary,
261
  lab_results=lab_results,
 
272
  )}"
273
 
274
  try:
275
+ response = run_inference(prompt=prompt, model_name="medgemma_4b", max_new_tokens=2048, temperature=0.2)
 
 
 
 
 
 
276
  parsed = safe_json_parse(response)
277
  if parsed:
278
  state["pharmacology_notes"] = json.dumps(parsed, indent=2)
279
 
 
280
  primary = parsed.get("primary_recommendation", {})
281
  recommendation = {
282
  "primary_antibiotic": primary.get("antibiotic"),
 
289
  "safety_alerts": [a.get("message") for a in parsed.get("safety_alerts", [])],
290
  }
291
 
 
292
  alt = parsed.get("alternative_recommendation", {})
293
  if alt.get("antibiotic"):
294
  recommendation["backup_antibiotic"] = alt.get("antibiotic")
295
 
296
  state["recommendation"] = recommendation
297
 
 
298
  for alert in parsed.get("safety_alerts", []):
299
  if alert.get("level") in ["WARNING", "CRITICAL"]:
300
  state.setdefault("safety_warnings", []).append(alert.get("message"))
301
 
 
302
  if primary.get("antibiotic"):
303
  safety_result = _run_txgemma_safety_check(
304
  antibiotic=primary.get("antibiotic"),
 
311
  )
312
  if safety_result:
313
  state.setdefault("debug_log", []).append(f"TxGemma safety: {safety_result}")
 
314
  else:
315
  state["pharmacology_notes"] = response
316
  state["recommendation"] = {"rationale": response}
 
324
  return state
325
 
326
 
 
 
 
 
327
  def _format_patient_data(state: InfectionState) -> str:
328
+ """Format patient fields from state into a readable string for prompt injection."""
329
  lines = []
330
 
331
  if state.get("patient_id"):
 
370
  crcl: Optional[float],
371
  medications: list,
372
  ) -> Optional[str]:
373
+ """Run a supplementary TxGemma toxicology check on the proposed prescription."""
 
 
 
 
374
  try:
375
  prompt = TXGEMMA_SAFETY_PROMPT.format(
376
  antibiotic=antibiotic,
 
381
  crcl=crcl or "Unknown",
382
  medications=", ".join(medications) if medications else "None",
383
  )
384
+ # Agent 4 safety check uses TxGemma 9B; on limited GPU the env var maps this to 2B
385
+ return run_inference(prompt=prompt, model_name="txgemma_9b", max_new_tokens=256, temperature=0.1)
 
 
 
 
 
 
 
 
386
  except Exception as e:
387
  logger.warning(f"TxGemma safety check failed: {e}")
388
  return None
389
 
390
 
 
 
 
 
391
  AGENTS = {
392
  "intake_historian": run_intake_historian,
393
  "vision_specialist": run_vision_specialist,
 
397
 
398
 
399
def run_agent(agent_name: str, state: InfectionState) -> InfectionState:
    """Dispatch *state* to the agent registered under *agent_name*.

    Args:
        agent_name: Key into the module-level AGENTS registry.
        state: Current infection state passed through to the agent.

    Returns:
        The state as updated by the selected agent.

    Raises:
        ValueError: If *agent_name* is not a registered agent.
    """
    agent_fn = AGENTS.get(agent_name)
    if agent_fn is None:
        raise ValueError(f"Unknown agent: {agent_name}")
    return agent_fn(state)
 
 
 
 
 
 
 
 
 
 
src/config.py CHANGED
@@ -1,6 +1,4 @@
1
 
2
- from __future__ import annotations
3
-
4
  import os
5
  from functools import lru_cache
6
  from pathlib import Path
@@ -9,104 +7,63 @@ from typing import Literal, Optional
9
  from dotenv import load_dotenv
10
  from pydantic import BaseModel, Field
11
 
12
-
13
- # Load variables from a local .env if present (handy for local dev)
14
  load_dotenv()
15
 
16
 
17
  class Settings(BaseModel):
18
  """
19
- Central configuration object for Med-I-C.
20
 
21
- Values are read from environment variables where possible so that
22
- the same code can run locally, on Kaggle, and in production.
23
  """
24
 
25
- # ------------------------------------------------------------------
26
- # General environment
27
- # ------------------------------------------------------------------
28
  environment: Literal["local", "kaggle", "production"] = Field(
29
  default_factory=lambda: os.getenv("MEDIC_ENV", "local")
30
  )
31
-
32
  project_root: Path = Field(
33
  default_factory=lambda: Path(__file__).resolve().parents[1]
34
  )
35
-
36
  data_dir: Path = Field(
37
- default_factory=lambda: Path(
38
- os.getenv("MEDIC_DATA_DIR", "data")
39
- )
40
  )
41
-
42
  chroma_db_dir: Path = Field(
43
- default_factory=lambda: Path(
44
- os.getenv("MEDIC_CHROMA_DB_DIR", "data/chroma_db")
45
- )
46
  )
47
 
48
- # ------------------------------------------------------------------
49
- # Model + deployment preferences
50
- # ------------------------------------------------------------------
51
  default_backend: Literal["vertex", "local"] = Field(
52
  default_factory=lambda: os.getenv("MEDIC_DEFAULT_BACKEND", "vertex") # type: ignore[arg-type]
53
  )
54
-
55
- # Quantization mode for local models
56
  quantization: Literal["none", "4bit"] = Field(
57
  default_factory=lambda: os.getenv("MEDIC_QUANTIZATION", "4bit") # type: ignore[arg-type]
58
  )
59
-
60
- # Embedding model used for ChromaDB / RAG
61
  embedding_model_name: str = Field(
62
- default_factory=lambda: os.getenv(
63
- "MEDIC_EMBEDDING_MODEL",
64
- "sentence-transformers/all-MiniLM-L6-v2",
65
- )
66
  )
67
 
68
- # ------------------------------------------------------------------
69
- # Vertex AI configuration (MedGemma / TxGemma hosted on Vertex)
70
- # ------------------------------------------------------------------
71
  use_vertex: bool = Field(
72
- default_factory=lambda: os.getenv("MEDIC_USE_VERTEX", "true").lower()
73
- in {"1", "true", "yes"}
74
  )
75
-
76
  vertex_project_id: Optional[str] = Field(
77
  default_factory=lambda: os.getenv("MEDIC_VERTEX_PROJECT_ID")
78
  )
79
  vertex_location: str = Field(
80
  default_factory=lambda: os.getenv("MEDIC_VERTEX_LOCATION", "us-central1")
81
  )
82
-
83
- # Model IDs as expected by Vertex / langchain-google-vertexai
84
  vertex_medgemma_4b_model: str = Field(
85
- default_factory=lambda: os.getenv(
86
- "MEDIC_VERTEX_MEDGEMMA_4B_MODEL",
87
- "med-gemma-4b-it",
88
- )
89
  )
90
  vertex_medgemma_27b_model: str = Field(
91
- default_factory=lambda: os.getenv(
92
- "MEDIC_VERTEX_MEDGEMMA_27B_MODEL",
93
- "med-gemma-27b-text-it",
94
- )
95
  )
96
  vertex_txgemma_9b_model: str = Field(
97
- default_factory=lambda: os.getenv(
98
- "MEDIC_VERTEX_TXGEMMA_9B_MODEL",
99
- "tx-gemma-9b",
100
- )
101
  )
102
  vertex_txgemma_2b_model: str = Field(
103
- default_factory=lambda: os.getenv(
104
- "MEDIC_VERTEX_TXGEMMA_2B_MODEL",
105
- "tx-gemma-2b",
106
- )
107
  )
108
-
109
- # Standard GOOGLE_APPLICATION_CREDENTIALS path, if needed
110
  google_application_credentials: Optional[Path] = Field(
111
  default_factory=lambda: (
112
  Path(os.environ["GOOGLE_APPLICATION_CREDENTIALS"])
@@ -115,9 +72,7 @@ class Settings(BaseModel):
115
  )
116
  )
117
 
118
- # ------------------------------------------------------------------
119
- # Local model paths (for offline / Kaggle GPU usage)
120
- # ------------------------------------------------------------------
121
  local_medgemma_4b_model: Optional[str] = Field(
122
  default_factory=lambda: os.getenv("MEDIC_LOCAL_MEDGEMMA_4B_MODEL")
123
  )
@@ -134,17 +89,6 @@ class Settings(BaseModel):
134
 
135
  @lru_cache(maxsize=1)
136
  def get_settings() -> Settings:
137
- """
138
- Return a cached Settings instance.
139
-
140
- Use this helper everywhere instead of instantiating Settings directly:
141
-
142
- from src.config import get_settings
143
- settings = get_settings()
144
- """
145
-
146
  return Settings()
147
 
148
-
149
- __all__ = ["Settings", "get_settings"]
150
-
 
1
 
 
 
2
  import os
3
  from functools import lru_cache
4
  from pathlib import Path
 
7
  from dotenv import load_dotenv
8
  from pydantic import BaseModel, Field
9
 
 
 
10
  load_dotenv()
11
 
12
 
13
  class Settings(BaseModel):
14
  """
15
+ All configuration for Med-I-C, read from environment variables.
16
 
17
+ Supports three deployment targets via MEDIC_ENV: local, kaggle, production.
18
+ Backend selection (vertex or local) is controlled by MEDIC_DEFAULT_BACKEND.
19
  """
20
 
 
 
 
21
  environment: Literal["local", "kaggle", "production"] = Field(
22
  default_factory=lambda: os.getenv("MEDIC_ENV", "local")
23
  )
 
24
  project_root: Path = Field(
25
  default_factory=lambda: Path(__file__).resolve().parents[1]
26
  )
 
27
  data_dir: Path = Field(
28
+ default_factory=lambda: Path(os.getenv("MEDIC_DATA_DIR", "data"))
 
 
29
  )
 
30
  chroma_db_dir: Path = Field(
31
+ default_factory=lambda: Path(os.getenv("MEDIC_CHROMA_DB_DIR", "data/chroma_db"))
 
 
32
  )
33
 
 
 
 
34
  default_backend: Literal["vertex", "local"] = Field(
35
  default_factory=lambda: os.getenv("MEDIC_DEFAULT_BACKEND", "vertex") # type: ignore[arg-type]
36
  )
37
+ # 4-bit quantization via bitsandbytes (local backend only)
 
38
  quantization: Literal["none", "4bit"] = Field(
39
  default_factory=lambda: os.getenv("MEDIC_QUANTIZATION", "4bit") # type: ignore[arg-type]
40
  )
 
 
41
  embedding_model_name: str = Field(
42
+ default_factory=lambda: os.getenv("MEDIC_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
 
 
 
43
  )
44
 
45
+ # Vertex AI settings
 
 
46
  use_vertex: bool = Field(
47
+ default_factory=lambda: os.getenv("MEDIC_USE_VERTEX", "true").lower() in {"1", "true", "yes"}
 
48
  )
 
49
  vertex_project_id: Optional[str] = Field(
50
  default_factory=lambda: os.getenv("MEDIC_VERTEX_PROJECT_ID")
51
  )
52
  vertex_location: str = Field(
53
  default_factory=lambda: os.getenv("MEDIC_VERTEX_LOCATION", "us-central1")
54
  )
 
 
55
  vertex_medgemma_4b_model: str = Field(
56
+ default_factory=lambda: os.getenv("MEDIC_VERTEX_MEDGEMMA_4B_MODEL", "med-gemma-4b-it")
 
 
 
57
  )
58
  vertex_medgemma_27b_model: str = Field(
59
+ default_factory=lambda: os.getenv("MEDIC_VERTEX_MEDGEMMA_27B_MODEL", "med-gemma-27b-text-it")
 
 
 
60
  )
61
  vertex_txgemma_9b_model: str = Field(
62
+ default_factory=lambda: os.getenv("MEDIC_VERTEX_TXGEMMA_9B_MODEL", "tx-gemma-9b")
 
 
 
63
  )
64
  vertex_txgemma_2b_model: str = Field(
65
+ default_factory=lambda: os.getenv("MEDIC_VERTEX_TXGEMMA_2B_MODEL", "tx-gemma-2b")
 
 
 
66
  )
 
 
67
  google_application_credentials: Optional[Path] = Field(
68
  default_factory=lambda: (
69
  Path(os.environ["GOOGLE_APPLICATION_CREDENTIALS"])
 
72
  )
73
  )
74
 
75
+ # Local HuggingFace model paths (used when MEDIC_DEFAULT_BACKEND=local)
 
 
76
  local_medgemma_4b_model: Optional[str] = Field(
77
  default_factory=lambda: os.getenv("MEDIC_LOCAL_MEDGEMMA_4B_MODEL")
78
  )
 
89
 
90
@lru_cache(maxsize=1)
def get_settings() -> Settings:
    """Return the process-wide cached Settings instance.

    Import and call this accessor instead of instantiating Settings
    directly, so environment variables are read only once per process.
    """
    settings = Settings()
    return settings
94
 
 
 
 
src/db/import_data.py CHANGED
@@ -1,7 +1,6 @@
1
  """Data import scripts for Med-I-C structured documents."""
2
 
3
  import pandas as pd
4
- import re
5
  from pathlib import Path
6
  from .database import (
7
  get_connection, init_database, execute_many,
@@ -10,7 +9,7 @@ from .database import (
10
 
11
 
12
  def safe_float(value):
13
- """Safely convert a value to float, returning None on failure."""
14
  if pd.isna(value):
15
  return None
16
  try:
@@ -20,7 +19,7 @@ def safe_float(value):
20
 
21
 
22
  def safe_int(value):
23
- """Safely convert a value to int, returning None on failure."""
24
  if pd.isna(value):
25
  return None
26
  try:
@@ -29,41 +28,46 @@ def safe_int(value):
29
  return None
30
 
31
 
 
 
 
 
 
 
 
32
  def classify_severity(description: str) -> str:
33
- """Classify drug interaction severity based on description keywords."""
 
 
 
 
 
34
  if not description:
35
  return "unknown"
36
 
37
  desc_lower = description.lower()
38
 
39
- # Major severity indicators
40
  major_keywords = [
41
  "cardiotoxic", "nephrotoxic", "hepatotoxic", "neurotoxic",
42
  "fatal", "death", "severe", "contraindicated", "arrhythmia",
43
  "qt prolongation", "seizure", "bleeding", "hemorrhage",
44
- "serotonin syndrome", "neuroleptic malignant"
45
  ]
46
-
47
- # Moderate severity indicators
48
  moderate_keywords = [
49
  "increase", "decrease", "reduce", "enhance", "inhibit",
50
  "metabolism", "concentration", "absorption", "excretion",
51
- "therapeutic effect", "adverse effect", "toxicity"
52
  ]
53
 
54
- for keyword in major_keywords:
55
- if keyword in desc_lower:
56
- return "major"
57
-
58
- for keyword in moderate_keywords:
59
- if keyword in desc_lower:
60
- return "moderate"
61
-
62
  return "minor"
63
 
64
 
65
  def import_eml_antibiotics() -> int:
66
- """Import WHO EML antibiotic classification data."""
67
  print("Importing EML antibiotic data...")
68
 
69
  eml_files = {
@@ -79,29 +83,21 @@ def import_eml_antibiotics() -> int:
79
  continue
80
 
81
  try:
82
- # Use openpyxl directly with read_only=True for faster loading
83
  import openpyxl
84
  wb = openpyxl.load_workbook(filepath, read_only=True)
85
  ws = wb.active
86
 
87
- # Get headers from first row
88
- headers = []
89
- for cell in ws[1]:
90
- headers.append(str(cell.value).strip().lower().replace(' ', '_') if cell.value else f'col_{len(headers)}')
91
 
92
- # Process data rows
93
- for row_idx, row in enumerate(ws.iter_rows(min_row=2, values_only=True), start=2):
94
  row_dict = dict(zip(headers, row))
95
-
96
  medicine = str(row_dict.get('medicine_name', row_dict.get('medicine', '')))
97
- if not medicine or medicine == 'None' or medicine == 'nan':
98
  continue
99
 
100
- def safe_str(val):
101
- if val is None or pd.isna(val):
102
- return ''
103
- return str(val)
104
-
105
  records.append((
106
  medicine,
107
  category,
@@ -114,20 +110,20 @@ def import_eml_antibiotics() -> int:
114
  ))
115
 
116
  wb.close()
117
- print(f" Loaded {len([r for r in records if r[1] == category])} from {category}")
118
 
119
  except Exception as e:
120
  print(f" Warning: Error reading {filepath}: {e}")
121
  continue
122
 
123
  if records:
124
- query = """
125
- INSERT INTO eml_antibiotics
126
- (medicine_name, who_category, eml_section, formulations,
127
- indication, atc_codes, combined_with, status)
128
- VALUES (?, ?, ?, ?, ?, ?, ?, ?)
129
- """
130
- execute_many(query, records)
131
  print(f" Imported {len(records)} EML antibiotic records total")
132
 
133
  return len(records)
@@ -143,95 +139,78 @@ def import_atlas_susceptibility() -> int:
143
  print(f" Warning: {filepath} not found, skipping...")
144
  return 0
145
 
146
- # Read the raw data to find the header row and extract region
147
  df_raw = pd.read_excel(filepath, sheet_name="Percent", header=None)
148
 
149
- # Extract region from the title (row 1)
150
  region = "Unknown"
151
- for idx, row in df_raw.head(5).iterrows():
152
  cell = str(row.iloc[0]) if pd.notna(row.iloc[0]) else ""
153
  if "from" in cell.lower():
154
- # Extract country from "Percentage Susceptibility from Argentina"
155
  parts = cell.split("from")
156
  if len(parts) > 1:
157
  region = parts[1].strip()
158
  break
159
 
160
- # Find the header row (contains 'Antibacterial' or 'N')
161
- header_row = 4 # Default
162
  for idx, row in df_raw.head(10).iterrows():
163
  if any('Antibacterial' in str(v) for v in row.values if pd.notna(v)):
164
  header_row = idx
165
  break
166
 
167
- # Read with proper header
168
  df = pd.read_excel(filepath, sheet_name="Percent", header=header_row)
169
-
170
- # Standardize column names
171
  df.columns = [str(col).strip().lower().replace(' ', '_').replace('.', '') for col in df.columns]
172
 
173
  records = []
174
  for _, row in df.iterrows():
175
  antibiotic = str(row.get('antibacterial', ''))
176
-
177
- # Skip empty or non-antibiotic rows
178
  if not antibiotic or antibiotic == 'nan' or 'omitted' in antibiotic.lower():
179
  continue
180
  if 'in vitro' in antibiotic.lower() or 'table cells' in antibiotic.lower():
181
  continue
182
 
183
- # Get susceptibility values
184
- n_value = row.get('n', None)
185
- pct_s = row.get('susc', row.get('susceptible', None))
186
- pct_i = row.get('int', row.get('intermediate', None))
187
- pct_r = row.get('res', row.get('resistant', None))
188
-
189
- # Use safe conversion functions
190
- n_int = safe_int(n_value)
191
- s_float = safe_float(pct_s)
192
 
193
  if n_int is not None and s_float is not None:
194
  records.append((
195
- "General", # Species - will be refined if more data available
196
- "", # Family
197
  antibiotic,
198
  s_float,
199
- safe_float(pct_i),
200
- safe_float(pct_r),
201
  n_int,
202
- 2024, # Year - from the data context
203
  region,
204
- "ATLAS"
205
  ))
206
 
207
  if records:
208
- query = """
209
- INSERT INTO atlas_susceptibility
210
- (species, family, antibiotic, percent_susceptible,
211
- percent_intermediate, percent_resistant, total_isolates,
212
- year, region, source)
213
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
214
- """
215
- execute_many(query, records)
216
  print(f" Imported {len(records)} ATLAS susceptibility records from {region}")
217
 
218
  return len(records)
219
 
220
 
221
  def import_mic_breakpoints() -> int:
222
- """Import EUCAST MIC breakpoint tables."""
223
  print("Importing MIC breakpoint data...")
224
 
225
  filepath = DOCS_DIR / "mic_breakpoints" / "v_16.0__BreakpointTables.xlsx"
226
-
227
  if not filepath.exists():
228
  print(f" Warning: {filepath} not found, skipping...")
229
  return 0
230
 
231
- # Get all sheet names
232
  xl = pd.ExcelFile(filepath)
233
-
234
- # Skip non-pathogen sheets
235
  skip_sheets = {'Content', 'Changes', 'Notes', 'Guidance', 'Dosages',
236
  'Technical uncertainty', 'PK PD breakpoints', 'PK PD cutoffs'}
237
 
@@ -239,58 +218,48 @@ def import_mic_breakpoints() -> int:
239
  for sheet_name in xl.sheet_names:
240
  if sheet_name in skip_sheets:
241
  continue
242
-
243
  try:
244
  df = pd.read_excel(filepath, sheet_name=sheet_name, header=None)
245
-
246
- # Try to find antibiotic data - look for rows with MIC values
247
- pathogen_group = sheet_name
248
-
249
- # Simple heuristic: look for rows that might contain antibiotic names and MIC values
250
- for idx, row in df.iterrows():
251
  row_values = [str(v).strip() for v in row.values if pd.notna(v)]
 
 
 
 
 
 
 
252
 
253
- # Look for rows that might be antibiotic entries
254
- if len(row_values) >= 2:
255
- potential_antibiotic = row_values[0]
256
-
257
- # Skip header-like rows
258
- if any(kw in potential_antibiotic.lower() for kw in
259
- ['antibiotic', 'agent', 'note', 'disk', 'mic', 'breakpoint']):
260
- continue
261
-
262
- # Try to extract MIC values (numbers)
263
- mic_values = []
264
- for v in row_values[1:]:
265
- try:
266
- mic_values.append(float(v.replace('≤', '').replace('>', '').replace('<', '').strip()))
267
- except (ValueError, AttributeError):
268
- pass
269
-
270
- if len(mic_values) >= 2 and len(potential_antibiotic) > 2:
271
- records.append((
272
- pathogen_group,
273
- potential_antibiotic,
274
- None, # route
275
- mic_values[0] if len(mic_values) > 0 else None, # S breakpoint
276
- mic_values[1] if len(mic_values) > 1 else None, # R breakpoint
277
- None, # disk S
278
- None, # disk R
279
- None, # notes
280
- "16.0"
281
- ))
282
  except Exception as e:
283
  print(f" Warning: Could not parse sheet '{sheet_name}': {e}")
284
  continue
285
 
286
  if records:
287
- query = """
288
- INSERT INTO mic_breakpoints
289
- (pathogen_group, antibiotic, route, mic_susceptible, mic_resistant,
290
- disk_susceptible, disk_resistant, notes, eucast_version)
291
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
292
- """
293
- execute_many(query, records)
294
  print(f" Imported {len(records)} MIC breakpoint records")
295
 
296
  return len(records)
@@ -303,36 +272,32 @@ INTERACTIONS_CSV = DOCS_DIR / "drug_safety" / "db_drug_interactions.csv"
303
 
304
  def _resolve_interactions_csv() -> Path | None:
305
  """
306
- Return the path to the drug interactions CSV, downloading it if needed.
307
 
308
- Resolution order:
309
- 1. docs/drug_safety/db_drug_interactions.csv (already present locally)
310
- 2. /kaggle/input/drug-drug-interactions/ (Kaggle notebook with dataset attached)
311
- 3. Kaggle API download (local dev with ~/.kaggle/kaggle.json)
312
  """
313
- # 1. Already present
314
  if INTERACTIONS_CSV.exists():
315
  return INTERACTIONS_CSV
316
 
317
- # 2. Kaggle input mount (dataset added via Kaggle UI)
318
- for candidate in KAGGLE_INPUT_DIR.glob("*.csv") if KAGGLE_INPUT_DIR.exists() else []:
319
- print(f" Found CSV in Kaggle input: {candidate}")
320
- return candidate
321
 
322
- # 3. Download via Kaggle API
323
  print(f" CSV not found — downloading from Kaggle dataset '{KAGGLE_DATASET}' ...")
324
  try:
325
- import kaggle # noqa: F401 triggers credential check
 
326
  dest = INTERACTIONS_CSV.parent
327
  dest.mkdir(parents=True, exist_ok=True)
328
- import subprocess
329
  result = subprocess.run(
330
- ["kaggle", "datasets", "download", "-d", KAGGLE_DATASET,
331
- "--unzip", "-p", str(dest)],
332
  capture_output=True, text=True,
333
  )
334
  if result.returncode == 0:
335
- # Find the downloaded CSV
336
  for f in dest.glob("*.csv"):
337
  print(f" Downloaded: {f.name}")
338
  return f
@@ -347,23 +312,18 @@ def _resolve_interactions_csv() -> Path | None:
347
 
348
 
349
  def import_drug_interactions(limit: int = None) -> int:
350
- """Import drug-drug interaction database from Kaggle dataset mghobashy/drug-drug-interactions."""
351
  print("Importing drug interactions data...")
352
 
353
  filepath = _resolve_interactions_csv()
354
-
355
  if filepath is None:
356
  print(" Skipping drug interactions — CSV unavailable.")
357
  print(f" To fix: attach the Kaggle dataset '{KAGGLE_DATASET}' to your notebook,")
358
  print(" or set up ~/.kaggle/kaggle.json for API access.")
359
  return 0
360
 
361
- # Read CSV in chunks due to large size
362
- chunk_size = 10000
363
  total_records = 0
364
-
365
- for chunk in pd.read_csv(filepath, chunksize=chunk_size):
366
- # Standardize column names
367
  chunk.columns = [col.strip().lower().replace(' ', '_') for col in chunk.columns]
368
 
369
  records = []
@@ -372,19 +332,14 @@ def import_drug_interactions(limit: int = None) -> int:
372
  drug_2 = str(row.get('drug_2', row.get('drug2', row.iloc[1] if len(row) > 1 else '')))
373
  description = str(row.get('interaction_description', row.get('description',
374
  row.get('interaction', row.iloc[2] if len(row) > 2 else ''))))
375
-
376
- severity = classify_severity(description)
377
-
378
  if drug_1 and drug_2:
379
- records.append((drug_1, drug_2, description, severity))
380
 
381
  if records:
382
- query = """
383
- INSERT INTO drug_interactions
384
- (drug_1, drug_2, interaction_description, severity)
385
- VALUES (?, ?, ?, ?)
386
- """
387
- execute_many(query, records)
388
  total_records += len(records)
389
 
390
  if limit and total_records >= limit:
@@ -395,24 +350,19 @@ def import_drug_interactions(limit: int = None) -> int:
395
 
396
 
397
  def import_all_data(interactions_limit: int = None) -> dict:
398
- """Import all structured data into the database."""
399
  print(f"\n{'='*50}")
400
  print("Med-I-C Data Import")
401
  print(f"{'='*50}\n")
402
 
403
- # Initialize database
404
  init_database()
405
 
406
- # Clear existing data
407
  with get_connection() as conn:
408
- conn.execute("DELETE FROM eml_antibiotics")
409
- conn.execute("DELETE FROM atlas_susceptibility")
410
- conn.execute("DELETE FROM mic_breakpoints")
411
- conn.execute("DELETE FROM drug_interactions")
412
  conn.commit()
413
  print("Cleared existing data\n")
414
 
415
- # Import all data
416
  results = {
417
  "eml_antibiotics": import_eml_antibiotics(),
418
  "atlas_susceptibility": import_atlas_susceptibility(),
@@ -430,5 +380,4 @@ def import_all_data(interactions_limit: int = None) -> dict:
430
 
431
 
432
  if __name__ == "__main__":
433
- # Import with a limit on interactions for faster demo
434
  import_all_data(interactions_limit=50000)
 
1
  """Data import scripts for Med-I-C structured documents."""
2
 
3
  import pandas as pd
 
4
  from pathlib import Path
5
  from .database import (
6
  get_connection, init_database, execute_many,
 
9
 
10
 
11
  def safe_float(value):
12
+ """Convert value to float; return None if the value is NaN or non-numeric."""
13
  if pd.isna(value):
14
  return None
15
  try:
 
19
 
20
 
21
  def safe_int(value):
22
+ """Convert value to int via float; return None if the value is NaN or non-numeric."""
23
  if pd.isna(value):
24
  return None
25
  try:
 
28
  return None
29
 
30
 
31
+ def safe_str(value) -> str:
32
+ """Convert value to string; return empty string for None or NaN."""
33
+ if value is None or pd.isna(value):
34
+ return ''
35
+ return str(value)
36
+
37
+
38
  def classify_severity(description: str) -> str:
39
+ """
40
+ Classify drug interaction severity from the interaction description text.
41
+
42
+ Returns 'major', 'moderate', or 'minor' based on keyword presence, or 'unknown' for an empty description.
43
+ Major keywords take precedence over moderate.
44
+ """
45
  if not description:
46
  return "unknown"
47
 
48
  desc_lower = description.lower()
49
 
 
50
  major_keywords = [
51
  "cardiotoxic", "nephrotoxic", "hepatotoxic", "neurotoxic",
52
  "fatal", "death", "severe", "contraindicated", "arrhythmia",
53
  "qt prolongation", "seizure", "bleeding", "hemorrhage",
54
+ "serotonin syndrome", "neuroleptic malignant",
55
  ]
 
 
56
  moderate_keywords = [
57
  "increase", "decrease", "reduce", "enhance", "inhibit",
58
  "metabolism", "concentration", "absorption", "excretion",
59
+ "therapeutic effect", "adverse effect", "toxicity",
60
  ]
61
 
62
+ if any(kw in desc_lower for kw in major_keywords):
63
+ return "major"
64
+ if any(kw in desc_lower for kw in moderate_keywords):
65
+ return "moderate"
 
 
 
 
66
  return "minor"
67
 
68
 
69
  def import_eml_antibiotics() -> int:
70
+ """Import WHO EML antibiotic classification data from the three AWaRe Excel files."""
71
  print("Importing EML antibiotic data...")
72
 
73
  eml_files = {
 
83
  continue
84
 
85
  try:
 
86
  import openpyxl
87
  wb = openpyxl.load_workbook(filepath, read_only=True)
88
  ws = wb.active
89
 
90
+ headers = [
91
+ str(cell.value).strip().lower().replace(' ', '_') if cell.value else f'col_{i}'
92
+ for i, cell in enumerate(ws[1])
93
+ ]
94
 
95
+ for row in ws.iter_rows(min_row=2, values_only=True):
 
96
  row_dict = dict(zip(headers, row))
 
97
  medicine = str(row_dict.get('medicine_name', row_dict.get('medicine', '')))
98
+ if not medicine or medicine in ('None', 'nan'):
99
  continue
100
 
 
 
 
 
 
101
  records.append((
102
  medicine,
103
  category,
 
110
  ))
111
 
112
  wb.close()
113
+ print(f" Loaded {sum(1 for r in records if r[1] == category)} from {category}")
114
 
115
  except Exception as e:
116
  print(f" Warning: Error reading {filepath}: {e}")
117
  continue
118
 
119
  if records:
120
+ execute_many(
121
+ """INSERT INTO eml_antibiotics
122
+ (medicine_name, who_category, eml_section, formulations,
123
+ indication, atc_codes, combined_with, status)
124
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
125
+ records,
126
+ )
127
  print(f" Imported {len(records)} EML antibiotic records total")
128
 
129
  return len(records)
 
139
  print(f" Warning: {filepath} not found, skipping...")
140
  return 0
141
 
 
142
  df_raw = pd.read_excel(filepath, sheet_name="Percent", header=None)
143
 
144
+ # Title row contains "Percentage Susceptibility from <Country>"
145
  region = "Unknown"
146
+ for _, row in df_raw.head(5).iterrows():
147
  cell = str(row.iloc[0]) if pd.notna(row.iloc[0]) else ""
148
  if "from" in cell.lower():
 
149
  parts = cell.split("from")
150
  if len(parts) > 1:
151
  region = parts[1].strip()
152
  break
153
 
154
+ # Locate the actual header row by finding "Antibacterial"
155
+ header_row = 4
156
  for idx, row in df_raw.head(10).iterrows():
157
  if any('Antibacterial' in str(v) for v in row.values if pd.notna(v)):
158
  header_row = idx
159
  break
160
 
 
161
  df = pd.read_excel(filepath, sheet_name="Percent", header=header_row)
 
 
162
  df.columns = [str(col).strip().lower().replace(' ', '_').replace('.', '') for col in df.columns]
163
 
164
  records = []
165
  for _, row in df.iterrows():
166
  antibiotic = str(row.get('antibacterial', ''))
 
 
167
  if not antibiotic or antibiotic == 'nan' or 'omitted' in antibiotic.lower():
168
  continue
169
  if 'in vitro' in antibiotic.lower() or 'table cells' in antibiotic.lower():
170
  continue
171
 
172
+ n_int = safe_int(row.get('n'))
173
+ s_float = safe_float(row.get('susc', row.get('susceptible')))
 
 
 
 
 
 
 
174
 
175
  if n_int is not None and s_float is not None:
176
  records.append((
177
+ "General",
178
+ "",
179
  antibiotic,
180
  s_float,
181
+ safe_float(row.get('int', row.get('intermediate'))),
182
+ safe_float(row.get('res', row.get('resistant'))),
183
  n_int,
184
+ 2024,
185
  region,
186
+ "ATLAS",
187
  ))
188
 
189
  if records:
190
+ execute_many(
191
+ """INSERT INTO atlas_susceptibility
192
+ (species, family, antibiotic, percent_susceptible,
193
+ percent_intermediate, percent_resistant, total_isolates,
194
+ year, region, source)
195
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
196
+ records,
197
+ )
198
  print(f" Imported {len(records)} ATLAS susceptibility records from {region}")
199
 
200
  return len(records)
201
 
202
 
203
  def import_mic_breakpoints() -> int:
204
+ """Import EUCAST MIC breakpoint tables from the Excel file."""
205
  print("Importing MIC breakpoint data...")
206
 
207
  filepath = DOCS_DIR / "mic_breakpoints" / "v_16.0__BreakpointTables.xlsx"
 
208
  if not filepath.exists():
209
  print(f" Warning: {filepath} not found, skipping...")
210
  return 0
211
 
 
212
  xl = pd.ExcelFile(filepath)
213
+ # These sheets contain metadata/guidance, not pathogen-specific breakpoints
 
214
  skip_sheets = {'Content', 'Changes', 'Notes', 'Guidance', 'Dosages',
215
  'Technical uncertainty', 'PK PD breakpoints', 'PK PD cutoffs'}
216
 
 
218
  for sheet_name in xl.sheet_names:
219
  if sheet_name in skip_sheets:
220
  continue
 
221
  try:
222
  df = pd.read_excel(filepath, sheet_name=sheet_name, header=None)
223
+ for _, row in df.iterrows():
 
 
 
 
 
224
  row_values = [str(v).strip() for v in row.values if pd.notna(v)]
225
+ if len(row_values) < 2:
226
+ continue
227
+
228
+ potential_antibiotic = row_values[0]
229
+ if any(kw in potential_antibiotic.lower() for kw in
230
+ ['antibiotic', 'agent', 'note', 'disk', 'mic', 'breakpoint']):
231
+ continue
232
 
233
+ # Extract numeric MIC values; strip inequality signs
234
+ mic_values = []
235
+ for v in row_values[1:]:
236
+ try:
237
+ mic_values.append(float(v.replace('≤', '').replace('>', '').replace('<', '').strip()))
238
+ except (ValueError, AttributeError):
239
+ pass
240
+
241
+ if len(mic_values) >= 2 and len(potential_antibiotic) > 2:
242
+ records.append((
243
+ sheet_name, # pathogen_group
244
+ potential_antibiotic,
245
+ None, # route
246
+ mic_values[0], # S breakpoint
247
+ mic_values[1], # R breakpoint
248
+ None, None, None, # disk S, disk R, notes
249
+ "16.0",
250
+ ))
 
 
 
 
 
 
 
 
 
 
 
251
  except Exception as e:
252
  print(f" Warning: Could not parse sheet '{sheet_name}': {e}")
253
  continue
254
 
255
  if records:
256
+ execute_many(
257
+ """INSERT INTO mic_breakpoints
258
+ (pathogen_group, antibiotic, route, mic_susceptible, mic_resistant,
259
+ disk_susceptible, disk_resistant, notes, eucast_version)
260
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
261
+ records,
262
+ )
263
  print(f" Imported {len(records)} MIC breakpoint records")
264
 
265
  return len(records)
 
272
 
273
  def _resolve_interactions_csv() -> Path | None:
274
  """
275
+ Find the drug interactions CSV file.
276
 
277
+ Checks in order:
278
+ 1. docs/drug_safety/db_drug_interactions.csv (local)
279
+ 2. /kaggle/input/drug-drug-interactions/ (Kaggle notebook with dataset attached)
280
+ 3. Kaggle API download (requires ~/.kaggle/kaggle.json)
281
  """
 
282
  if INTERACTIONS_CSV.exists():
283
  return INTERACTIONS_CSV
284
 
285
+ if KAGGLE_INPUT_DIR.exists():
286
+ for candidate in KAGGLE_INPUT_DIR.glob("*.csv"):
287
+ print(f" Found CSV in Kaggle input: {candidate}")
288
+ return candidate
289
 
 
290
  print(f" CSV not found — downloading from Kaggle dataset '{KAGGLE_DATASET}' ...")
291
  try:
292
+ import kaggle # noqa: F401 triggers credential check
293
+ import subprocess
294
  dest = INTERACTIONS_CSV.parent
295
  dest.mkdir(parents=True, exist_ok=True)
 
296
  result = subprocess.run(
297
+ ["kaggle", "datasets", "download", "-d", KAGGLE_DATASET, "--unzip", "-p", str(dest)],
 
298
  capture_output=True, text=True,
299
  )
300
  if result.returncode == 0:
 
301
  for f in dest.glob("*.csv"):
302
  print(f" Downloaded: {f.name}")
303
  return f
 
312
 
313
 
314
  def import_drug_interactions(limit: int = None) -> int:
315
+ """Import drug-drug interactions from the DDInter CSV (Kaggle dataset mghobashy/drug-drug-interactions)."""
316
  print("Importing drug interactions data...")
317
 
318
  filepath = _resolve_interactions_csv()
 
319
  if filepath is None:
320
  print(" Skipping drug interactions — CSV unavailable.")
321
  print(f" To fix: attach the Kaggle dataset '{KAGGLE_DATASET}' to your notebook,")
322
  print(" or set up ~/.kaggle/kaggle.json for API access.")
323
  return 0
324
 
 
 
325
  total_records = 0
326
+ for chunk in pd.read_csv(filepath, chunksize=10000):
 
 
327
  chunk.columns = [col.strip().lower().replace(' ', '_') for col in chunk.columns]
328
 
329
  records = []
 
332
  drug_2 = str(row.get('drug_2', row.get('drug2', row.iloc[1] if len(row) > 1 else '')))
333
  description = str(row.get('interaction_description', row.get('description',
334
  row.get('interaction', row.iloc[2] if len(row) > 2 else ''))))
 
 
 
335
  if drug_1 and drug_2:
336
+ records.append((drug_1, drug_2, description, classify_severity(description)))
337
 
338
  if records:
339
+ execute_many(
340
+ "INSERT INTO drug_interactions (drug_1, drug_2, interaction_description, severity) VALUES (?, ?, ?, ?)",
341
+ records,
342
+ )
 
 
343
  total_records += len(records)
344
 
345
  if limit and total_records >= limit:
 
350
 
351
 
352
  def import_all_data(interactions_limit: int = None) -> dict:
353
+ """Initialize the database and import all structured data sources."""
354
  print(f"\n{'='*50}")
355
  print("Med-I-C Data Import")
356
  print(f"{'='*50}\n")
357
 
 
358
  init_database()
359
 
 
360
  with get_connection() as conn:
361
+ for table in ("eml_antibiotics", "atlas_susceptibility", "mic_breakpoints", "drug_interactions"):
362
+ conn.execute(f"DELETE FROM {table}")
 
 
363
  conn.commit()
364
  print("Cleared existing data\n")
365
 
 
366
  results = {
367
  "eml_antibiotics": import_eml_antibiotics(),
368
  "atlas_susceptibility": import_atlas_susceptibility(),
 
380
 
381
 
382
  if __name__ == "__main__":
 
383
  import_all_data(interactions_limit=50000)
src/graph.py CHANGED
@@ -1,17 +1,13 @@
1
  """
2
- LangGraph Orchestrator for Med-I-C Multi-Agent System.
3
 
4
- Implements the infection lifecycle workflow with conditional routing:
 
5
 
6
- Stage 1 (Empirical - no lab results):
7
- Intake Historian -> Clinical Pharmacologist
8
-
9
- Stage 2 (Targeted - lab results available):
10
- Intake Historian -> Vision Specialist -> Trend Analyst -> Clinical Pharmacologist
11
  """
12
 
13
- from __future__ import annotations
14
-
15
  import logging
16
  from typing import Literal
17
 
@@ -28,189 +24,59 @@ from .state import InfectionState
28
  logger = logging.getLogger(__name__)
29
 
30
 
31
- # =============================================================================
32
- # NODE FUNCTIONS (Wrapper for agents)
33
- # =============================================================================
34
-
35
- def intake_historian_node(state: InfectionState) -> InfectionState:
36
- """Node 1: Run Intake Historian agent."""
37
- logger.info("Graph: Executing Intake Historian node")
38
- return run_intake_historian(state)
39
-
40
-
41
- def vision_specialist_node(state: InfectionState) -> InfectionState:
42
- """Node 2: Run Vision Specialist agent."""
43
- logger.info("Graph: Executing Vision Specialist node")
44
- return run_vision_specialist(state)
45
-
46
-
47
- def trend_analyst_node(state: InfectionState) -> InfectionState:
48
- """Node 3: Run Trend Analyst agent."""
49
- logger.info("Graph: Executing Trend Analyst node")
50
- return run_trend_analyst(state)
51
-
52
-
53
- def clinical_pharmacologist_node(state: InfectionState) -> InfectionState:
54
- """Node 4: Run Clinical Pharmacologist agent."""
55
- logger.info("Graph: Executing Clinical Pharmacologist node")
56
- return run_clinical_pharmacologist(state)
57
-
58
-
59
- # =============================================================================
60
- # CONDITIONAL ROUTING FUNCTIONS
61
- # =============================================================================
62
-
63
  def route_after_intake(state: InfectionState) -> Literal["vision_specialist", "clinical_pharmacologist"]:
64
- """
65
- Determine routing after Intake Historian.
66
-
67
- Routes to Vision Specialist if:
68
- - stage is "targeted" AND
69
- - route_to_vision is True (i.e., we have lab data to process)
70
-
71
- Otherwise routes directly to Clinical Pharmacologist (empirical path).
72
- """
73
- stage = state.get("stage", "empirical")
74
- has_lab_data = state.get("route_to_vision", False)
75
-
76
- if stage == "targeted" and has_lab_data:
77
- logger.info("Graph: Routing to Vision Specialist (targeted path)")
78
  return "vision_specialist"
79
- else:
80
- logger.info("Graph: Routing to Clinical Pharmacologist (empirical path)")
81
- return "clinical_pharmacologist"
82
 
83
 
84
  def route_after_vision(state: InfectionState) -> Literal["trend_analyst", "clinical_pharmacologist"]:
85
- """
86
- Determine routing after Vision Specialist.
87
-
88
- Routes to Trend Analyst if:
89
- - route_to_trend_analyst is True (i.e., we have MIC data to analyze)
90
-
91
- Otherwise skips to Clinical Pharmacologist.
92
- """
93
- should_analyze_trends = state.get("route_to_trend_analyst", False)
94
-
95
- if should_analyze_trends:
96
- logger.info("Graph: Routing to Trend Analyst")
97
  return "trend_analyst"
98
- else:
99
- logger.info("Graph: Skipping Trend Analyst, routing to Clinical Pharmacologist")
100
- return "clinical_pharmacologist"
101
-
102
 
103
- # =============================================================================
104
- # GRAPH CONSTRUCTION
105
- # =============================================================================
106
 
107
  def build_infection_graph() -> StateGraph:
108
- """
109
- Build the LangGraph StateGraph for the infection lifecycle workflow.
110
-
111
- Returns:
112
- Compiled StateGraph ready for execution
113
- """
114
- # Create the graph with InfectionState as the state schema
115
  graph = StateGraph(InfectionState)
116
 
117
- # Add nodes
118
- graph.add_node("intake_historian", intake_historian_node)
119
- graph.add_node("vision_specialist", vision_specialist_node)
120
- graph.add_node("trend_analyst", trend_analyst_node)
121
- graph.add_node("clinical_pharmacologist", clinical_pharmacologist_node)
122
 
123
- # Set entry point
124
  graph.set_entry_point("intake_historian")
125
 
126
- # Add conditional edges from intake_historian
127
  graph.add_conditional_edges(
128
  "intake_historian",
129
  route_after_intake,
130
- {
131
- "vision_specialist": "vision_specialist",
132
- "clinical_pharmacologist": "clinical_pharmacologist",
133
- }
134
  )
135
-
136
- # Add conditional edges from vision_specialist
137
  graph.add_conditional_edges(
138
  "vision_specialist",
139
  route_after_vision,
140
- {
141
- "trend_analyst": "trend_analyst",
142
- "clinical_pharmacologist": "clinical_pharmacologist",
143
- }
144
  )
145
 
146
- # Add edge from trend_analyst to clinical_pharmacologist
147
  graph.add_edge("trend_analyst", "clinical_pharmacologist")
148
-
149
- # Add edge from clinical_pharmacologist to END
150
  graph.add_edge("clinical_pharmacologist", END)
151
 
152
  return graph
153
 
154
 
155
- def compile_graph():
156
  """
157
- Build and compile the graph for execution.
158
 
159
- Returns:
160
- Compiled graph that can be invoked with .invoke(state)
161
  """
162
- graph = build_infection_graph()
163
- return graph.compile()
164
-
165
-
166
- # =============================================================================
167
- # EXECUTION HELPERS
168
- # =============================================================================
169
-
170
- def run_pipeline(
171
- patient_data: dict,
172
- labs_raw_text: str | None = None,
173
- ) -> InfectionState:
174
- """
175
- Run the full infection lifecycle pipeline.
176
-
177
- This is the main entry point for executing the multi-agent workflow.
178
-
179
- Args:
180
- patient_data: Dict containing patient information:
181
- - age_years: Patient age
182
- - weight_kg: Patient weight
183
- - sex: "male" or "female"
184
- - serum_creatinine_mg_dl: Serum creatinine (optional)
185
- - medications: List of current medications
186
- - allergies: List of allergies
187
- - comorbidities: List of comorbidities
188
- - infection_site: Site of infection
189
- - suspected_source: Suspected pathogen/source
190
-
191
- labs_raw_text: Raw text from lab report (if available).
192
- If provided, triggers targeted (Stage 2) pathway.
193
-
194
- Returns:
195
- Final InfectionState with recommendation
196
-
197
- Example:
198
- >>> state = run_pipeline(
199
- ... patient_data={
200
- ... "age_years": 65,
201
- ... "weight_kg": 70,
202
- ... "sex": "male",
203
- ... "serum_creatinine_mg_dl": 1.2,
204
- ... "medications": ["metformin", "lisinopril"],
205
- ... "allergies": ["penicillin"],
206
- ... "infection_site": "urinary",
207
- ... "suspected_source": "community UTI",
208
- ... },
209
- ... labs_raw_text="E. coli isolated. Ciprofloxacin MIC: 0.5 mg/L (S)"
210
- ... )
211
- >>> print(state["recommendation"]["primary_antibiotic"])
212
- """
213
- # Build initial state from patient data
214
  initial_state: InfectionState = {
215
  "age_years": patient_data.get("age_years"),
216
  "weight_kg": patient_data.get("weight_kg"),
@@ -224,59 +90,34 @@ def run_pipeline(
224
  "suspected_source": patient_data.get("suspected_source"),
225
  "country_or_region": patient_data.get("country_or_region"),
226
  "vitals": patient_data.get("vitals", {}),
 
227
  }
228
 
229
- # Add lab data if provided
230
  if labs_raw_text:
231
  initial_state["labs_raw_text"] = labs_raw_text
232
- initial_state["stage"] = "targeted"
233
- else:
234
- initial_state["stage"] = "empirical"
235
-
236
- # Compile and run the graph
237
- logger.info(f"Starting pipeline execution (stage: {initial_state['stage']})")
238
-
239
- compiled_graph = compile_graph()
240
- final_state = compiled_graph.invoke(initial_state)
241
-
242
- logger.info("Pipeline execution complete")
243
 
 
 
 
 
244
  return final_state
245
 
246
 
247
  def run_empirical_pipeline(patient_data: dict) -> InfectionState:
248
- """
249
- Run Stage 1 (Empirical) pipeline only.
250
-
251
- Shorthand for run_pipeline without lab data.
252
- """
253
- return run_pipeline(patient_data, labs_raw_text=None)
254
 
255
 
256
  def run_targeted_pipeline(patient_data: dict, labs_raw_text: str) -> InfectionState:
257
- """
258
- Run Stage 2 (Targeted) pipeline with lab data.
259
-
260
- Shorthand for run_pipeline with lab data.
261
- """
262
  return run_pipeline(patient_data, labs_raw_text=labs_raw_text)
263
 
264
 
265
- # =============================================================================
266
- # VISUALIZATION (for debugging)
267
- # =============================================================================
268
-
269
  def get_graph_mermaid() -> str:
270
- """
271
- Get Mermaid diagram representation of the graph.
272
-
273
- Useful for documentation and debugging.
274
- """
275
- graph = build_infection_graph()
276
  try:
277
- return graph.compile().get_graph().draw_mermaid()
278
  except Exception:
279
- # Fallback: return manual diagram
280
  return """
281
  graph TD
282
  A[intake_historian] --> B{route_after_intake}
@@ -288,13 +129,3 @@ graph TD
288
  F --> E
289
  E --> G[END]
290
  """
291
-
292
-
293
- __all__ = [
294
- "build_infection_graph",
295
- "compile_graph",
296
- "run_pipeline",
297
- "run_empirical_pipeline",
298
- "run_targeted_pipeline",
299
- "get_graph_mermaid",
300
- ]
 
1
  """
2
+ LangGraph orchestrator for the infection lifecycle workflow.
3
 
4
+ Stage 1 (empirical - no lab results):
5
+ Intake Historian → Clinical Pharmacologist
6
 
7
+ Stage 2 (targeted - lab results available):
8
+ Intake Historian → Vision Specialist → [Trend Analyst →] Clinical Pharmacologist
 
 
 
9
  """
10
 
 
 
11
  import logging
12
  from typing import Literal
13
 
 
24
  logger = logging.getLogger(__name__)
25
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def route_after_intake(state: InfectionState) -> Literal["vision_specialist", "clinical_pharmacologist"]:
28
+ """Route to Vision Specialist if we have lab text to parse; otherwise go straight to pharmacologist."""
29
+ if state.get("stage") == "targeted" and state.get("route_to_vision"):
30
+ logger.info("Graph: routing to Vision Specialist (targeted path)")
 
 
 
 
 
 
 
 
 
 
 
31
  return "vision_specialist"
32
+ logger.info("Graph: routing to Clinical Pharmacologist (empirical path)")
33
+ return "clinical_pharmacologist"
 
34
 
35
 
36
  def route_after_vision(state: InfectionState) -> Literal["trend_analyst", "clinical_pharmacologist"]:
37
+ """Route to Trend Analyst if Vision Specialist extracted MIC values."""
38
+ if state.get("route_to_trend_analyst"):
39
+ logger.info("Graph: routing to Trend Analyst")
 
 
 
 
 
 
 
 
 
40
  return "trend_analyst"
41
+ logger.info("Graph: skipping Trend Analyst (no MIC data)")
42
+ return "clinical_pharmacologist"
 
 
43
 
 
 
 
44
 
45
  def build_infection_graph() -> StateGraph:
46
+ """Build and return the compiled LangGraph for the infection pipeline."""
 
 
 
 
 
 
47
  graph = StateGraph(InfectionState)
48
 
49
+ graph.add_node("intake_historian", run_intake_historian)
50
+ graph.add_node("vision_specialist", run_vision_specialist)
51
+ graph.add_node("trend_analyst", run_trend_analyst)
52
+ graph.add_node("clinical_pharmacologist", run_clinical_pharmacologist)
 
53
 
 
54
  graph.set_entry_point("intake_historian")
55
 
 
56
  graph.add_conditional_edges(
57
  "intake_historian",
58
  route_after_intake,
59
+ {"vision_specialist": "vision_specialist", "clinical_pharmacologist": "clinical_pharmacologist"},
 
 
 
60
  )
 
 
61
  graph.add_conditional_edges(
62
  "vision_specialist",
63
  route_after_vision,
64
+ {"trend_analyst": "trend_analyst", "clinical_pharmacologist": "clinical_pharmacologist"},
 
 
 
65
  )
66
 
 
67
  graph.add_edge("trend_analyst", "clinical_pharmacologist")
 
 
68
  graph.add_edge("clinical_pharmacologist", END)
69
 
70
  return graph
71
 
72
 
73
+ def run_pipeline(patient_data: dict, labs_raw_text: str | None = None) -> InfectionState:
74
  """
75
+ Run the full infection pipeline and return the final state.
76
 
77
+ Pass labs_raw_text to trigger the targeted (Stage 2) pathway.
78
+ Without it, only the empirical (Stage 1) pathway runs.
79
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  initial_state: InfectionState = {
81
  "age_years": patient_data.get("age_years"),
82
  "weight_kg": patient_data.get("weight_kg"),
 
90
  "suspected_source": patient_data.get("suspected_source"),
91
  "country_or_region": patient_data.get("country_or_region"),
92
  "vitals": patient_data.get("vitals", {}),
93
+ "stage": "targeted" if labs_raw_text else "empirical",
94
  }
95
 
 
96
  if labs_raw_text:
97
  initial_state["labs_raw_text"] = labs_raw_text
 
 
 
 
 
 
 
 
 
 
 
98
 
99
+ logger.info(f"Starting pipeline (stage: {initial_state['stage']})")
100
+ compiled = build_infection_graph().compile()
101
+ final_state = compiled.invoke(initial_state)
102
+ logger.info("Pipeline complete")
103
  return final_state
104
 
105
 
106
  def run_empirical_pipeline(patient_data: dict) -> InfectionState:
107
+ """Shorthand for run_pipeline without lab data (Stage 1)."""
108
+ return run_pipeline(patient_data)
 
 
 
 
109
 
110
 
111
  def run_targeted_pipeline(patient_data: dict, labs_raw_text: str) -> InfectionState:
112
+ """Shorthand for run_pipeline with lab data (Stage 2)."""
 
 
 
 
113
  return run_pipeline(patient_data, labs_raw_text=labs_raw_text)
114
 
115
 
 
 
 
 
116
  def get_graph_mermaid() -> str:
117
+ """Return a Mermaid diagram of the graph (for documentation and debugging)."""
 
 
 
 
 
118
  try:
119
+ return build_infection_graph().compile().get_graph().draw_mermaid()
120
  except Exception:
 
121
  return """
122
  graph TD
123
  A[intake_historian] --> B{route_after_intake}
 
129
  F --> E
130
  E --> G[END]
131
  """
 
 
 
 
 
 
 
 
 
 
src/loader.py CHANGED
@@ -1,22 +1,17 @@
1
 
2
- from __future__ import annotations
3
-
4
  import logging
5
  from functools import lru_cache
6
- from typing import Any, Callable, Dict, Literal, Optional, Tuple
7
 
8
  from .config import get_settings
9
 
10
-
11
  logger = logging.getLogger(__name__)
12
 
13
  TextBackend = Literal["vertex", "local"]
14
  TextModelName = Literal["medgemma_4b", "medgemma_27b", "txgemma_9b", "txgemma_2b"]
15
 
16
 
17
- def _resolve_backend(
18
- requested: Optional[TextBackend],
19
- ) -> TextBackend:
20
  settings = get_settings()
21
  backend = requested or settings.default_backend # type: ignore[assignment]
22
  if backend == "vertex" and not settings.use_vertex:
@@ -27,23 +22,16 @@ def _resolve_backend(
27
 
28
  @lru_cache(maxsize=8)
29
  def _get_vertex_chat_model(model_name: TextModelName):
30
- """
31
- Lazily construct a Vertex AI chat model via langchain-google-vertexai.
32
-
33
- Returns an object with an .invoke(str) method; we wrap this in a simple
34
- callable for downstream use.
35
- """
36
-
37
  try:
38
  from langchain_google_vertexai import ChatVertexAI
39
- except Exception as exc: # pragma: no cover - import-time failure
40
  raise RuntimeError(
41
  "langchain-google-vertexai is not available; "
42
  "install it or switch MEDIC_DEFAULT_BACKEND=local."
43
  ) from exc
44
 
45
  settings = get_settings()
46
-
47
  if settings.vertex_project_id is None:
48
  raise RuntimeError(
49
  "MEDIC_VERTEX_PROJECT_ID is not set. "
@@ -56,40 +44,28 @@ def _get_vertex_chat_model(model_name: TextModelName):
56
  "txgemma_9b": settings.vertex_txgemma_9b_model,
57
  "txgemma_2b": settings.vertex_txgemma_2b_model,
58
  }
59
- model_id = model_id_map[model_name]
60
 
61
  llm = ChatVertexAI(
62
- model=model_id,
63
  project=settings.vertex_project_id,
64
  location=settings.vertex_location,
65
  temperature=0.2,
66
  )
67
 
68
  def _call(prompt: str, **kwargs: Any) -> str:
69
- """Thin wrapper returning plain text from ChatVertexAI."""
70
-
71
  result = llm.invoke(prompt, **kwargs)
72
- # langchain BaseMessage or plain string
73
- content = getattr(result, "content", result)
74
- return str(content)
75
 
76
  return _call
77
 
78
 
79
  @lru_cache(maxsize=8)
80
  def _get_local_causal_lm(model_name: TextModelName):
81
- """
82
- Lazily load a local transformers model for offline / Kaggle usage.
83
-
84
- Assumes model paths are provided via MEDIC_LOCAL_* env vars and that
85
- the appropriate model weights are available in the environment.
86
- """
87
-
88
  from transformers import AutoModelForCausalLM, AutoTokenizer
89
  import torch
90
 
91
  settings = get_settings()
92
-
93
  model_path_map: Dict[TextModelName, Optional[str]] = {
94
  "medgemma_4b": settings.local_medgemma_4b_model,
95
  "medgemma_27b": settings.local_medgemma_27b_model,
@@ -101,31 +77,19 @@ def _get_local_causal_lm(model_name: TextModelName):
101
  if not model_path:
102
  raise RuntimeError(
103
  f"No local model path configured for {model_name}. "
104
- f"Set MEDIC_LOCAL_*_MODEL or use the Vertex backend."
105
  )
106
 
107
- load_kwargs: Dict[str, Any] = {
108
- "device_map": "auto",
109
- }
110
-
111
- # Optional 4-bit quantization via bitsandbytes
112
- if get_settings().quantization == "4bit":
113
  load_kwargs["load_in_4bit"] = True
114
 
115
  tokenizer = AutoTokenizer.from_pretrained(model_path)
116
  model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
117
 
118
- def _call(
119
- prompt: str,
120
- max_new_tokens: int = 512,
121
- temperature: float = 0.2,
122
- **generate_kwargs: Any,
123
- ) -> str:
124
- inputs = tokenizer(prompt, return_tensors="pt")
125
- inputs = {k: v.to(model.device) for k, v in inputs.items()}
126
-
127
  do_sample = temperature > 0
128
-
129
  with torch.no_grad():
130
  output_ids = model.generate(
131
  **inputs,
@@ -134,11 +98,9 @@ def _get_local_causal_lm(model_name: TextModelName):
134
  max_new_tokens=max_new_tokens,
135
  **generate_kwargs,
136
  )
137
-
138
- # Drop the prompt tokens and decode only the completion
139
- generated_ids = output_ids[0, inputs["input_ids"].shape[1] :]
140
- text = tokenizer.decode(generated_ids, skip_special_tokens=True)
141
- return text.strip()
142
 
143
  return _call
144
 
@@ -148,22 +110,9 @@ def get_text_model(
148
  model_name: TextModelName = "medgemma_4b",
149
  backend: Optional[TextBackend] = None,
150
  ) -> Callable[..., str]:
151
- """
152
- Return a cached text-generation callable.
153
-
154
- Example:
155
-
156
- from src.loader import get_text_model
157
- model = get_text_model("medgemma_4b")
158
- answer = model("Explain ESBL in simple terms.")
159
- """
160
-
161
- resolved_backend = _resolve_backend(backend)
162
-
163
- if resolved_backend == "vertex":
164
- return _get_vertex_chat_model(model_name)
165
- else:
166
- return _get_local_causal_lm(model_name)
167
 
168
 
169
  def run_inference(
@@ -174,28 +123,7 @@ def run_inference(
174
  temperature: float = 0.2,
175
  **kwargs: Any,
176
  ) -> str:
177
- """
178
- Convenience wrapper around `get_text_model`.
179
-
180
- This is the simplest entry point to use inside agents:
181
-
182
- from src.loader import run_inference
183
- text = run_inference(prompt, model_name="medgemma_4b")
184
- """
185
-
186
  model = get_text_model(model_name=model_name, backend=backend)
187
- return model(
188
- prompt,
189
- max_new_tokens=max_new_tokens,
190
- temperature=temperature,
191
- **kwargs,
192
- )
193
-
194
-
195
- __all__ = [
196
- "TextBackend",
197
- "TextModelName",
198
- "get_text_model",
199
- "run_inference",
200
- ]
201
 
 
1
 
 
 
2
  import logging
3
  from functools import lru_cache
4
+ from typing import Any, Callable, Dict, Literal, Optional
5
 
6
  from .config import get_settings
7
 
 
8
  logger = logging.getLogger(__name__)
9
 
10
  TextBackend = Literal["vertex", "local"]
11
  TextModelName = Literal["medgemma_4b", "medgemma_27b", "txgemma_9b", "txgemma_2b"]
12
 
13
 
14
+ def _resolve_backend(requested: Optional[TextBackend]) -> TextBackend:
 
 
15
  settings = get_settings()
16
  backend = requested or settings.default_backend # type: ignore[assignment]
17
  if backend == "vertex" and not settings.use_vertex:
 
22
 
23
  @lru_cache(maxsize=8)
24
  def _get_vertex_chat_model(model_name: TextModelName):
25
+ """Load a Vertex AI chat model and return a callable that takes a prompt string."""
 
 
 
 
 
 
26
  try:
27
  from langchain_google_vertexai import ChatVertexAI
28
+ except Exception as exc:
29
  raise RuntimeError(
30
  "langchain-google-vertexai is not available; "
31
  "install it or switch MEDIC_DEFAULT_BACKEND=local."
32
  ) from exc
33
 
34
  settings = get_settings()
 
35
  if settings.vertex_project_id is None:
36
  raise RuntimeError(
37
  "MEDIC_VERTEX_PROJECT_ID is not set. "
 
44
  "txgemma_9b": settings.vertex_txgemma_9b_model,
45
  "txgemma_2b": settings.vertex_txgemma_2b_model,
46
  }
 
47
 
48
  llm = ChatVertexAI(
49
+ model=model_id_map[model_name],
50
  project=settings.vertex_project_id,
51
  location=settings.vertex_location,
52
  temperature=0.2,
53
  )
54
 
55
  def _call(prompt: str, **kwargs: Any) -> str:
 
 
56
  result = llm.invoke(prompt, **kwargs)
57
+ return str(getattr(result, "content", result))
 
 
58
 
59
  return _call
60
 
61
 
62
  @lru_cache(maxsize=8)
63
  def _get_local_causal_lm(model_name: TextModelName):
64
+ """Load a local HuggingFace causal LM and return a generation callable."""
 
 
 
 
 
 
65
  from transformers import AutoModelForCausalLM, AutoTokenizer
66
  import torch
67
 
68
  settings = get_settings()
 
69
  model_path_map: Dict[TextModelName, Optional[str]] = {
70
  "medgemma_4b": settings.local_medgemma_4b_model,
71
  "medgemma_27b": settings.local_medgemma_27b_model,
 
77
  if not model_path:
78
  raise RuntimeError(
79
  f"No local model path configured for {model_name}. "
80
+ "Set MEDIC_LOCAL_*_MODEL or use the Vertex backend."
81
  )
82
 
83
+ load_kwargs: Dict[str, Any] = {"device_map": "auto"}
84
+ if settings.quantization == "4bit":
 
 
 
 
85
  load_kwargs["load_in_4bit"] = True
86
 
87
  tokenizer = AutoTokenizer.from_pretrained(model_path)
88
  model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
89
 
90
+ def _call(prompt: str, max_new_tokens: int = 512, temperature: float = 0.2, **generate_kwargs: Any) -> str:
91
+ inputs = {k: v.to(model.device) for k, v in tokenizer(prompt, return_tensors="pt").items()}
 
 
 
 
 
 
 
92
  do_sample = temperature > 0
 
93
  with torch.no_grad():
94
  output_ids = model.generate(
95
  **inputs,
 
98
  max_new_tokens=max_new_tokens,
99
  **generate_kwargs,
100
  )
101
+ # Decode only the newly generated tokens, not the input prompt
102
+ generated_ids = output_ids[0, inputs["input_ids"].shape[1]:]
103
+ return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
 
 
104
 
105
  return _call
106
 
 
110
  model_name: TextModelName = "medgemma_4b",
111
  backend: Optional[TextBackend] = None,
112
  ) -> Callable[..., str]:
113
+ """Return a cached callable for the requested model and backend."""
114
+ resolved = _resolve_backend(backend)
115
+ return _get_vertex_chat_model(model_name) if resolved == "vertex" else _get_local_causal_lm(model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
 
118
  def run_inference(
 
123
  temperature: float = 0.2,
124
  **kwargs: Any,
125
  ) -> str:
126
+ """Run inference with the specified model. This is the primary entry point for agents."""
 
 
 
 
 
 
 
 
127
  model = get_text_model(model_name=model_name, backend=backend)
128
+ return model(prompt, max_new_tokens=max_new_tokens, temperature=temperature, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
src/prompts.py CHANGED
@@ -1,18 +1,7 @@
1
- """
2
- Prompt templates for Med-I-C multi-agent system.
3
-
4
- Each agent has a specific role in the infection lifecycle workflow:
5
- - Agent 1: Intake Historian - Parse patient data, risk factors, calculate CrCl
6
- - Agent 2: Vision Specialist - Extract structured data from lab reports (images/PDFs)
7
- - Agent 3: Trend Analyst - Detect MIC creep and resistance velocity
8
- - Agent 4: Clinical Pharmacologist - Final Rx recommendations + safety checks
9
- """
10
 
11
- from __future__ import annotations
12
 
13
- # =============================================================================
14
- # AGENT 1: INTAKE HISTORIAN
15
- # =============================================================================
16
 
17
  INTAKE_HISTORIAN_SYSTEM = """You are an expert clinical intake specialist. Your role is to:
18
 
@@ -66,9 +55,7 @@ RAG CONTEXT (Relevant Guidelines):
66
  Provide your structured assessment following the system instructions."""
67
 
68
 
69
- # =============================================================================
70
- # AGENT 2: VISION SPECIALIST
71
- # =============================================================================
72
 
73
  VISION_SPECIALIST_SYSTEM = """You are an expert medical laboratory data extraction specialist. Your role is to:
74
 
@@ -131,9 +118,7 @@ Flag any critical findings that require urgent attention.
131
  Provide your structured extraction following the system instructions."""
132
 
133
 
134
- # =============================================================================
135
- # AGENT 3: TREND ANALYST
136
- # =============================================================================
137
 
138
  TREND_ANALYST_SYSTEM = """You are an expert antimicrobial resistance trend analyst. Your role is to:
139
 
@@ -195,9 +180,7 @@ Analyze the trend, calculate risk level, and provide recommendations.
195
  Follow the system instructions for output format."""
196
 
197
 
198
- # =============================================================================
199
- # AGENT 4: CLINICAL PHARMACOLOGIST
200
- # =============================================================================
201
 
202
  CLINICAL_PHARMACOLOGIST_SYSTEM = """You are an expert clinical pharmacologist specializing in infectious diseases and antimicrobial stewardship. Your role is to:
203
 
@@ -291,9 +274,7 @@ Provide your final recommendation following the system instructions.
291
  Ensure all safety checks are performed and documented."""
292
 
293
 
294
- # =============================================================================
295
- # TXGEMMA SAFETY CHECKER (Supplementary)
296
- # =============================================================================
297
 
298
  TXGEMMA_SAFETY_PROMPT = """Evaluate the safety profile of the following antibiotic prescription:
299
 
@@ -315,9 +296,7 @@ Evaluate for:
315
  Provide a brief safety assessment (2-3 sentences) and a risk rating (LOW/MODERATE/HIGH)."""
316
 
317
 
318
- # =============================================================================
319
- # HELPER TEMPLATES
320
- # =============================================================================
321
 
322
  ERROR_RECOVERY_PROMPT = """The previous agent encountered an error or produced invalid output.
323
 
@@ -338,18 +317,3 @@ CLINICAL SCENARIO:
338
  - Local resistance patterns: {local_resistance}
339
 
340
  Recommend appropriate empirical therapy following WHO AWaRe principles."""
341
-
342
-
343
- __all__ = [
344
- "INTAKE_HISTORIAN_SYSTEM",
345
- "INTAKE_HISTORIAN_PROMPT",
346
- "VISION_SPECIALIST_SYSTEM",
347
- "VISION_SPECIALIST_PROMPT",
348
- "TREND_ANALYST_SYSTEM",
349
- "TREND_ANALYST_PROMPT",
350
- "CLINICAL_PHARMACOLOGIST_SYSTEM",
351
- "CLINICAL_PHARMACOLOGIST_PROMPT",
352
- "TXGEMMA_SAFETY_PROMPT",
353
- "ERROR_RECOVERY_PROMPT",
354
- "FALLBACK_EMPIRICAL_PROMPT",
355
- ]
 
1
+ """Prompt templates for each agent in the Med-I-C pipeline."""
 
 
 
 
 
 
 
 
2
 
 
3
 
4
+ # --- Agent 1: Intake Historian ---
 
 
5
 
6
  INTAKE_HISTORIAN_SYSTEM = """You are an expert clinical intake specialist. Your role is to:
7
 
 
55
  Provide your structured assessment following the system instructions."""
56
 
57
 
58
+ # --- Agent 2: Vision Specialist ---
 
 
59
 
60
  VISION_SPECIALIST_SYSTEM = """You are an expert medical laboratory data extraction specialist. Your role is to:
61
 
 
118
  Provide your structured extraction following the system instructions."""
119
 
120
 
121
+ # --- Agent 3: Trend Analyst ---
 
 
122
 
123
  TREND_ANALYST_SYSTEM = """You are an expert antimicrobial resistance trend analyst. Your role is to:
124
 
 
180
  Follow the system instructions for output format."""
181
 
182
 
183
+ # --- Agent 4: Clinical Pharmacologist ---
 
 
184
 
185
  CLINICAL_PHARMACOLOGIST_SYSTEM = """You are an expert clinical pharmacologist specializing in infectious diseases and antimicrobial stewardship. Your role is to:
186
 
 
274
  Ensure all safety checks are performed and documented."""
275
 
276
 
277
+ # --- TxGemma safety check (supplementary, not primary decision-making) ---
 
 
278
 
279
  TXGEMMA_SAFETY_PROMPT = """Evaluate the safety profile of the following antibiotic prescription:
280
 
 
296
  Provide a brief safety assessment (2-3 sentences) and a risk rating (LOW/MODERATE/HIGH)."""
297
 
298
 
299
+ # --- Fallback templates ---
 
 
300
 
301
  ERROR_RECOVERY_PROMPT = """The previous agent encountered an error or produced invalid output.
302
 
 
317
  - Local resistance patterns: {local_resistance}
318
 
319
  Recommend appropriate empirical therapy following WHO AWaRe principles."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/rag.py CHANGED
@@ -1,116 +1,80 @@
1
  """
2
- RAG (Retrieval Augmented Generation) module for Med-I-C.
3
 
4
- Provides unified retrieval across multiple knowledge collections:
5
- - antibiotic_guidelines: WHO/IDSA treatment guidelines
6
- - mic_breakpoints: EUCAST/CLSI breakpoint tables
7
- - drug_safety: Drug interactions, warnings, contraindications
8
- - pathogen_resistance: Regional resistance patterns
9
  """
10
 
11
- from __future__ import annotations
12
-
13
  import logging
14
- from pathlib import Path
15
  from typing import Any, Dict, List, Optional
16
 
17
  from .config import get_settings
18
 
19
  logger = logging.getLogger(__name__)
20
 
21
-
22
- # =============================================================================
23
- # CHROMA CLIENT & EMBEDDING SETUP
24
- # =============================================================================
25
-
26
  _chroma_client = None
27
  _embedding_function = None
28
 
29
 
30
  def get_chroma_client():
31
- """Get or create ChromaDB persistent client."""
32
  global _chroma_client
33
  if _chroma_client is None:
34
  import chromadb
35
-
36
- settings = get_settings()
37
- chroma_path = settings.chroma_db_dir
38
  chroma_path.mkdir(parents=True, exist_ok=True)
39
  _chroma_client = chromadb.PersistentClient(path=str(chroma_path))
40
  return _chroma_client
41
 
42
 
43
  def get_embedding_function():
44
- """Get or create the embedding function."""
45
  global _embedding_function
46
  if _embedding_function is None:
47
  from chromadb.utils import embedding_functions
48
-
49
- settings = get_settings()
50
  _embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
51
- model_name=settings.embedding_model_name.split("/")[-1]
52
  )
53
  return _embedding_function
54
 
55
 
56
  def get_collection(name: str):
57
- """
58
- Get a ChromaDB collection by name.
59
-
60
- Returns None if collection doesn't exist.
61
- """
62
- client = get_chroma_client()
63
- ef = get_embedding_function()
64
-
65
  try:
66
- return client.get_collection(name=name, embedding_function=ef)
67
  except Exception:
68
  logger.warning(f"Collection '{name}' not found")
69
  return None
70
 
71
 
72
- # =============================================================================
73
- # COLLECTION-SPECIFIC RETRIEVERS
74
- # =============================================================================
75
-
76
  def search_antibiotic_guidelines(
77
  query: str,
78
  n_results: int = 5,
79
  pathogen_filter: Optional[str] = None,
80
  ) -> List[Dict[str, Any]]:
81
- """
82
- Search antibiotic treatment guidelines.
83
-
84
- Args:
85
- query: Search query
86
- n_results: Number of results to return
87
- pathogen_filter: Optional pathogen type filter (e.g., "ESBL-E", "CRE")
88
-
89
- Returns:
90
- List of relevant guideline excerpts with metadata
91
- """
92
  collection = get_collection("idsa_treatment_guidelines")
93
  if collection is None:
94
- logger.warning("idsa_treatment_guidelines collection not available")
95
  return []
96
-
97
- where_filter = None
98
- if pathogen_filter:
99
- where_filter = {"pathogen_type": pathogen_filter}
100
-
101
  try:
 
102
  results = collection.query(
103
  query_texts=[query],
104
  n_results=n_results,
105
- where=where_filter,
106
  include=["documents", "metadatas", "distances"],
107
  )
 
108
  except Exception as e:
109
  logger.error(f"Error querying guidelines: {e}")
110
  return []
111
 
112
- return _format_results(results)
113
-
114
 
115
  def search_mic_breakpoints(
116
  query: str,
@@ -118,79 +82,45 @@ def search_mic_breakpoints(
118
  organism: Optional[str] = None,
119
  antibiotic: Optional[str] = None,
120
  ) -> List[Dict[str, Any]]:
121
- """
122
- Search MIC breakpoint reference documentation.
123
-
124
- Args:
125
- query: Search query
126
- n_results: Number of results
127
- organism: Optional organism name filter
128
- antibiotic: Optional antibiotic name filter
129
-
130
- Returns:
131
- List of relevant breakpoint information
132
- """
133
  collection = get_collection("mic_reference_docs")
134
  if collection is None:
135
- logger.warning("mic_reference_docs collection not available")
136
  return []
137
-
138
- # Build query with organism/antibiotic context if provided
139
- enhanced_query = query
140
- if organism:
141
- enhanced_query = f"{organism} {enhanced_query}"
142
- if antibiotic:
143
- enhanced_query = f"{antibiotic} {enhanced_query}"
144
-
145
  try:
146
  results = collection.query(
147
  query_texts=[enhanced_query],
148
  n_results=n_results,
149
  include=["documents", "metadatas", "distances"],
150
  )
 
151
  except Exception as e:
152
  logger.error(f"Error querying breakpoints: {e}")
153
  return []
154
 
155
- return _format_results(results)
156
-
157
 
158
  def search_drug_safety(
159
  query: str,
160
  n_results: int = 5,
161
  drug_name: Optional[str] = None,
162
  ) -> List[Dict[str, Any]]:
163
- """
164
- Search drug safety information (interactions, warnings, contraindications).
165
-
166
- Args:
167
- query: Search query
168
- n_results: Number of results
169
- drug_name: Optional drug name to focus search
170
-
171
- Returns:
172
- List of relevant safety information
173
- """
174
  collection = get_collection("drug_safety")
175
  if collection is None:
176
- # Fallback: try existing collections
177
- logger.warning("drug_safety collection not available")
178
  return []
179
-
180
  enhanced_query = f"{drug_name} {query}" if drug_name else query
181
-
182
  try:
183
  results = collection.query(
184
  query_texts=[enhanced_query],
185
  n_results=n_results,
186
  include=["documents", "metadatas", "distances"],
187
  )
 
188
  except Exception as e:
189
  logger.error(f"Error querying drug safety: {e}")
190
  return []
191
 
192
- return _format_results(results)
193
-
194
 
195
  def search_resistance_patterns(
196
  query: str,
@@ -198,45 +128,22 @@ def search_resistance_patterns(
198
  organism: Optional[str] = None,
199
  region: Optional[str] = None,
200
  ) -> List[Dict[str, Any]]:
201
- """
202
- Search pathogen resistance pattern data.
203
-
204
- Args:
205
- query: Search query
206
- n_results: Number of results
207
- organism: Optional organism filter
208
- region: Optional geographic region filter
209
-
210
- Returns:
211
- List of relevant resistance data
212
- """
213
  collection = get_collection("pathogen_resistance")
214
  if collection is None:
215
- logger.warning("pathogen_resistance collection not available")
216
  return []
217
-
218
- enhanced_query = query
219
- if organism:
220
- enhanced_query = f"{organism} {enhanced_query}"
221
- if region:
222
- enhanced_query = f"{region} {enhanced_query}"
223
-
224
  try:
225
  results = collection.query(
226
  query_texts=[enhanced_query],
227
  n_results=n_results,
228
  include=["documents", "metadatas", "distances"],
229
  )
 
230
  except Exception as e:
231
  logger.error(f"Error querying resistance patterns: {e}")
232
  return []
233
 
234
- return _format_results(results)
235
-
236
-
237
- # =============================================================================
238
- # UNIFIED CONTEXT RETRIEVER
239
- # =============================================================================
240
 
241
  def get_context_for_agent(
242
  agent_name: str,
@@ -245,238 +152,104 @@ def get_context_for_agent(
245
  n_results: int = 3,
246
  ) -> str:
247
  """
248
- Get formatted RAG context string for a specific agent.
249
-
250
- This is the main entry point for agents to retrieve context.
251
-
252
- Args:
253
- agent_name: Name of the requesting agent
254
- query: The primary search query
255
- patient_context: Optional dict with patient-specific info
256
- n_results: Number of results per collection
257
 
258
- Returns:
259
- Formatted context string for injection into prompts
 
 
 
260
  """
261
- context_parts = []
262
- patient_context = patient_context or {}
263
 
264
  if agent_name == "intake_historian":
265
- # Get empirical therapy guidelines
266
- guidelines = search_antibiotic_guidelines(
267
- query=query,
268
- n_results=n_results,
269
- pathogen_filter=patient_context.get("pathogen_type"),
270
- )
271
  if guidelines:
272
- context_parts.append("RELEVANT TREATMENT GUIDELINES:")
273
  for g in guidelines:
274
- context_parts.append(f"- {g['content'][:500]}...")
275
- context_parts.append(f" [Source: {g.get('source', 'IDSA Guidelines')}]")
276
 
277
  elif agent_name == "vision_specialist":
278
- # Get MIC reference info for lab interpretation
279
- breakpoints = search_mic_breakpoints(
280
- query=query,
281
- n_results=n_results,
282
- organism=patient_context.get("organism"),
283
- antibiotic=patient_context.get("antibiotic"),
284
- )
285
  if breakpoints:
286
- context_parts.append("RELEVANT BREAKPOINT INFORMATION:")
287
  for b in breakpoints:
288
- context_parts.append(f"- {b['content'][:400]}...")
289
 
290
  elif agent_name == "trend_analyst":
291
- # Get breakpoints and resistance trends
292
  breakpoints = search_mic_breakpoints(
293
- query=f"breakpoint {patient_context.get('organism', '')} {patient_context.get('antibiotic', '')}",
294
- n_results=n_results,
295
- )
296
- resistance = search_resistance_patterns(
297
- query=query,
298
  n_results=n_results,
299
- organism=patient_context.get("organism"),
300
- region=patient_context.get("region"),
301
  )
302
-
303
  if breakpoints:
304
- context_parts.append("EUCAST BREAKPOINT DATA:")
305
  for b in breakpoints:
306
- context_parts.append(f"- {b['content'][:400]}...")
307
-
308
  if resistance:
309
- context_parts.append("\nRESISTANCE PATTERN DATA:")
310
  for r in resistance:
311
- context_parts.append(f"- {r['content'][:400]}...")
312
 
313
  elif agent_name == "clinical_pharmacologist":
314
- # Get comprehensive context for final recommendation
315
- guidelines = search_antibiotic_guidelines(
316
- query=query,
317
- n_results=n_results,
318
- )
319
- safety = search_drug_safety(
320
- query=query,
321
- n_results=n_results,
322
- drug_name=patient_context.get("proposed_antibiotic"),
323
- )
324
-
325
  if guidelines:
326
- context_parts.append("TREATMENT GUIDELINES:")
327
  for g in guidelines:
328
- context_parts.append(f"- {g['content'][:400]}...")
329
-
330
  if safety:
331
- context_parts.append("\nDRUG SAFETY INFORMATION:")
332
  for s in safety:
333
- context_parts.append(f"- {s['content'][:400]}...")
334
 
335
  else:
336
- # Generic retrieval
337
  guidelines = search_antibiotic_guidelines(query, n_results=n_results)
338
- if guidelines:
339
- for g in guidelines:
340
- context_parts.append(f"- {g['content'][:500]}...")
341
-
342
- if not context_parts:
343
- return "No relevant context found in knowledge base."
344
-
345
- return "\n".join(context_parts)
346
-
347
 
348
- def get_context_string(
349
- query: str,
350
- collections: Optional[List[str]] = None,
351
- n_results_per_collection: int = 3,
352
- **filters,
353
- ) -> str:
354
- """
355
- Get a combined context string from multiple collections.
356
 
357
- This is a simpler interface for general-purpose RAG retrieval.
358
-
359
- Args:
360
- query: Search query
361
- collections: List of collection names to search (defaults to all)
362
- n_results_per_collection: Results per collection
363
- **filters: Additional filters (organism, antibiotic, region, etc.)
364
-
365
- Returns:
366
- Combined context string
367
- """
368
- default_collections = [
369
- "idsa_treatment_guidelines",
370
- "mic_reference_docs",
371
- ]
372
- collections = collections or default_collections
373
-
374
- context_parts = []
375
-
376
- for collection_name in collections:
377
- if collection_name == "idsa_treatment_guidelines":
378
- results = search_antibiotic_guidelines(
379
- query,
380
- n_results=n_results_per_collection,
381
- pathogen_filter=filters.get("pathogen_type"),
382
- )
383
- elif collection_name == "mic_reference_docs":
384
- results = search_mic_breakpoints(
385
- query,
386
- n_results=n_results_per_collection,
387
- organism=filters.get("organism"),
388
- antibiotic=filters.get("antibiotic"),
389
- )
390
- elif collection_name == "drug_safety":
391
- results = search_drug_safety(
392
- query,
393
- n_results=n_results_per_collection,
394
- drug_name=filters.get("drug_name"),
395
- )
396
- elif collection_name == "pathogen_resistance":
397
- results = search_resistance_patterns(
398
- query,
399
- n_results=n_results_per_collection,
400
- organism=filters.get("organism"),
401
- region=filters.get("region"),
402
- )
403
- else:
404
- continue
405
-
406
- if results:
407
- context_parts.append(f"=== {collection_name.upper()} ===")
408
- for r in results:
409
- context_parts.append(r["content"])
410
- context_parts.append(f"[Relevance: {1 - r.get('distance', 0):.2f}]")
411
- context_parts.append("")
412
-
413
- return "\n".join(context_parts) if context_parts else "No relevant context found."
414
-
415
-
416
- # =============================================================================
417
- # HELPER FUNCTIONS
418
- # =============================================================================
419
 
420
  def _format_results(results: Dict[str, Any]) -> List[Dict[str, Any]]:
421
- """Format ChromaDB query results into a standard format."""
422
  if not results or not results.get("documents"):
423
  return []
424
 
425
- formatted = []
426
  documents = results["documents"][0] if results["documents"] else []
427
  metadatas = results.get("metadatas", [[]])[0]
428
  distances = results.get("distances", [[]])[0]
429
 
430
- for i, doc in enumerate(documents):
431
- formatted.append({
432
  "content": doc,
433
  "metadata": metadatas[i] if i < len(metadatas) else {},
434
  "distance": distances[i] if i < len(distances) else None,
435
  "source": metadatas[i].get("source", "Unknown") if i < len(metadatas) else "Unknown",
436
  "relevance_score": 1 - (distances[i] if i < len(distances) else 0),
437
- })
438
-
439
- return formatted
440
 
441
 
442
  def list_available_collections() -> List[str]:
443
- """List all available ChromaDB collections."""
444
- client = get_chroma_client()
445
  try:
446
- collections = client.list_collections()
447
- return [c.name for c in collections]
448
  except Exception as e:
449
  logger.error(f"Error listing collections: {e}")
450
  return []
451
 
452
 
453
  def get_collection_info(name: str) -> Optional[Dict[str, Any]]:
454
- """Get information about a specific collection."""
455
  collection = get_collection(name)
456
  if collection is None:
457
  return None
458
-
459
  try:
460
- return {
461
- "name": collection.name,
462
- "count": collection.count(),
463
- "metadata": collection.metadata,
464
- }
465
  except Exception as e:
466
  logger.error(f"Error getting collection info: {e}")
467
  return None
468
-
469
-
470
- __all__ = [
471
- "get_chroma_client",
472
- "get_embedding_function",
473
- "get_collection",
474
- "search_antibiotic_guidelines",
475
- "search_mic_breakpoints",
476
- "search_drug_safety",
477
- "search_resistance_patterns",
478
- "get_context_for_agent",
479
- "get_context_string",
480
- "list_available_collections",
481
- "get_collection_info",
482
- ]
 
1
  """
2
+ RAG module for Med-I-C.
3
 
4
+ Retrieves context from four ChromaDB collections:
5
+ - idsa_treatment_guidelines: IDSA 2024 AMR guidance
6
+ - mic_reference_docs: EUCAST v16.0 breakpoint tables
7
+ - drug_safety: Drug interactions and contraindications
8
+ - pathogen_resistance: ATLAS regional susceptibility data
9
  """
10
 
 
 
11
  import logging
 
12
  from typing import Any, Dict, List, Optional
13
 
14
  from .config import get_settings
15
 
16
  logger = logging.getLogger(__name__)
17
 
18
+ # Module-level singletons; initialized lazily to avoid import-time side effects
 
 
 
 
19
  _chroma_client = None
20
  _embedding_function = None
21
 
22
 
23
  def get_chroma_client():
24
+ """Return the ChromaDB persistent client, creating it on first call."""
25
  global _chroma_client
26
  if _chroma_client is None:
27
  import chromadb
28
+ chroma_path = get_settings().chroma_db_dir
 
 
29
  chroma_path.mkdir(parents=True, exist_ok=True)
30
  _chroma_client = chromadb.PersistentClient(path=str(chroma_path))
31
  return _chroma_client
32
 
33
 
34
  def get_embedding_function():
35
+ """Return the SentenceTransformer embedding function, creating it on first call."""
36
  global _embedding_function
37
  if _embedding_function is None:
38
  from chromadb.utils import embedding_functions
39
+ # Use only the model short name (not the full HuggingFace path)
40
+ model_short_name = get_settings().embedding_model_name.split("/")[-1]
41
  _embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
42
+ model_name=model_short_name
43
  )
44
  return _embedding_function
45
 
46
 
47
  def get_collection(name: str):
48
+ """Return a ChromaDB collection by name, or None if it does not exist."""
 
 
 
 
 
 
 
49
  try:
50
+ return get_chroma_client().get_collection(name=name, embedding_function=get_embedding_function())
51
  except Exception:
52
  logger.warning(f"Collection '{name}' not found")
53
  return None
54
 
55
 
 
 
 
 
56
  def search_antibiotic_guidelines(
57
  query: str,
58
  n_results: int = 5,
59
  pathogen_filter: Optional[str] = None,
60
  ) -> List[Dict[str, Any]]:
61
+ """Search the IDSA treatment guidelines collection."""
 
 
 
 
 
 
 
 
 
 
62
  collection = get_collection("idsa_treatment_guidelines")
63
  if collection is None:
 
64
  return []
 
 
 
 
 
65
  try:
66
+ where = {"pathogen_type": pathogen_filter} if pathogen_filter else None
67
  results = collection.query(
68
  query_texts=[query],
69
  n_results=n_results,
70
+ where=where,
71
  include=["documents", "metadatas", "distances"],
72
  )
73
+ return _format_results(results)
74
  except Exception as e:
75
  logger.error(f"Error querying guidelines: {e}")
76
  return []
77
 
 
 
78
 
79
  def search_mic_breakpoints(
80
  query: str,
 
82
  organism: Optional[str] = None,
83
  antibiotic: Optional[str] = None,
84
  ) -> List[Dict[str, Any]]:
85
+ """Search the EUCAST MIC breakpoint reference collection."""
 
 
 
 
 
 
 
 
 
 
 
86
  collection = get_collection("mic_reference_docs")
87
  if collection is None:
 
88
  return []
89
+ # Prepend organism/antibiotic to query to narrow semantic search
90
+ enhanced_query = " ".join(filter(None, [organism, antibiotic, query]))
 
 
 
 
 
 
91
  try:
92
  results = collection.query(
93
  query_texts=[enhanced_query],
94
  n_results=n_results,
95
  include=["documents", "metadatas", "distances"],
96
  )
97
+ return _format_results(results)
98
  except Exception as e:
99
  logger.error(f"Error querying breakpoints: {e}")
100
  return []
101
 
 
 
102
 
103
  def search_drug_safety(
104
  query: str,
105
  n_results: int = 5,
106
  drug_name: Optional[str] = None,
107
  ) -> List[Dict[str, Any]]:
108
+ """Search the drug safety collection (interactions, warnings, contraindications)."""
 
 
 
 
 
 
 
 
 
 
109
  collection = get_collection("drug_safety")
110
  if collection is None:
 
 
111
  return []
 
112
  enhanced_query = f"{drug_name} {query}" if drug_name else query
 
113
  try:
114
  results = collection.query(
115
  query_texts=[enhanced_query],
116
  n_results=n_results,
117
  include=["documents", "metadatas", "distances"],
118
  )
119
+ return _format_results(results)
120
  except Exception as e:
121
  logger.error(f"Error querying drug safety: {e}")
122
  return []
123
 
 
 
124
 
125
  def search_resistance_patterns(
126
  query: str,
 
128
  organism: Optional[str] = None,
129
  region: Optional[str] = None,
130
  ) -> List[Dict[str, Any]]:
131
+ """Search the ATLAS pathogen resistance collection."""
 
 
 
 
 
 
 
 
 
 
 
132
  collection = get_collection("pathogen_resistance")
133
  if collection is None:
 
134
  return []
135
+ enhanced_query = " ".join(filter(None, [region, organism, query]))
 
 
 
 
 
 
136
  try:
137
  results = collection.query(
138
  query_texts=[enhanced_query],
139
  n_results=n_results,
140
  include=["documents", "metadatas", "distances"],
141
  )
142
+ return _format_results(results)
143
  except Exception as e:
144
  logger.error(f"Error querying resistance patterns: {e}")
145
  return []
146
 
 
 
 
 
 
 
147
 
148
  def get_context_for_agent(
149
  agent_name: str,
 
152
  n_results: int = 3,
153
  ) -> str:
154
  """
155
+ Return a formatted context string for a specific agent.
 
 
 
 
 
 
 
 
156
 
157
+ Each agent draws from the collections most relevant to its task:
158
+ - intake_historian: IDSA guidelines
159
+ - vision_specialist: MIC breakpoints
160
+ - trend_analyst: MIC breakpoints + resistance patterns
161
+ - clinical_pharmacologist: guidelines + drug safety
162
  """
163
+ ctx = patient_context or {}
164
+ parts = []
165
 
166
  if agent_name == "intake_historian":
167
+ guidelines = search_antibiotic_guidelines(query, n_results=n_results, pathogen_filter=ctx.get("pathogen_type"))
 
 
 
 
 
168
  if guidelines:
169
+ parts.append("RELEVANT TREATMENT GUIDELINES:")
170
  for g in guidelines:
171
+ parts.append(f"- {g['content'][:500]}...")
172
+ parts.append(f" [Source: {g.get('source', 'IDSA Guidelines')}]")
173
 
174
  elif agent_name == "vision_specialist":
175
+ breakpoints = search_mic_breakpoints(query, n_results=n_results, organism=ctx.get("organism"), antibiotic=ctx.get("antibiotic"))
 
 
 
 
 
 
176
  if breakpoints:
177
+ parts.append("RELEVANT BREAKPOINT INFORMATION:")
178
  for b in breakpoints:
179
+ parts.append(f"- {b['content'][:400]}...")
180
 
181
  elif agent_name == "trend_analyst":
 
182
  breakpoints = search_mic_breakpoints(
183
+ f"breakpoint {ctx.get('organism', '')} {ctx.get('antibiotic', '')}",
 
 
 
 
184
  n_results=n_results,
 
 
185
  )
186
+ resistance = search_resistance_patterns(query, n_results=n_results, organism=ctx.get("organism"), region=ctx.get("region"))
187
  if breakpoints:
188
+ parts.append("EUCAST BREAKPOINT DATA:")
189
  for b in breakpoints:
190
+ parts.append(f"- {b['content'][:400]}...")
 
191
  if resistance:
192
+ parts.append("\nRESISTANCE PATTERN DATA:")
193
  for r in resistance:
194
+ parts.append(f"- {r['content'][:400]}...")
195
 
196
  elif agent_name == "clinical_pharmacologist":
197
+ guidelines = search_antibiotic_guidelines(query, n_results=n_results)
198
+ safety = search_drug_safety(query, n_results=n_results, drug_name=ctx.get("proposed_antibiotic"))
 
 
 
 
 
 
 
 
 
199
  if guidelines:
200
+ parts.append("TREATMENT GUIDELINES:")
201
  for g in guidelines:
202
+ parts.append(f"- {g['content'][:400]}...")
 
203
  if safety:
204
+ parts.append("\nDRUG SAFETY INFORMATION:")
205
  for s in safety:
206
+ parts.append(f"- {s['content'][:400]}...")
207
 
208
  else:
 
209
  guidelines = search_antibiotic_guidelines(query, n_results=n_results)
210
+ for g in guidelines:
211
+ parts.append(f"- {g['content'][:500]}...")
 
 
 
 
 
 
 
212
 
213
+ return "\n".join(parts) if parts else "No relevant context found in knowledge base."
 
 
 
 
 
 
 
214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
  def _format_results(results: Dict[str, Any]) -> List[Dict[str, Any]]:
217
+ """Flatten ChromaDB query results into a list of dicts."""
218
  if not results or not results.get("documents"):
219
  return []
220
 
 
221
  documents = results["documents"][0] if results["documents"] else []
222
  metadatas = results.get("metadatas", [[]])[0]
223
  distances = results.get("distances", [[]])[0]
224
 
225
+ return [
226
+ {
227
  "content": doc,
228
  "metadata": metadatas[i] if i < len(metadatas) else {},
229
  "distance": distances[i] if i < len(distances) else None,
230
  "source": metadatas[i].get("source", "Unknown") if i < len(metadatas) else "Unknown",
231
  "relevance_score": 1 - (distances[i] if i < len(distances) else 0),
232
+ }
233
+ for i, doc in enumerate(documents)
234
+ ]
235
 
236
 
237
  def list_available_collections() -> List[str]:
238
+ """Return names of all ChromaDB collections that exist."""
 
239
  try:
240
+ return [c.name for c in get_chroma_client().list_collections()]
 
241
  except Exception as e:
242
  logger.error(f"Error listing collections: {e}")
243
  return []
244
 
245
 
246
  def get_collection_info(name: str) -> Optional[Dict[str, Any]]:
247
+ """Return count and metadata for a collection, or None if it does not exist."""
248
  collection = get_collection(name)
249
  if collection is None:
250
  return None
 
251
  try:
252
+ return {"name": collection.name, "count": collection.count(), "metadata": collection.metadata}
 
 
 
 
253
  except Exception as e:
254
  logger.error(f"Error getting collection info: {e}")
255
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/state.py CHANGED
@@ -1,12 +1,9 @@
1
 
2
- from __future__ import annotations
3
-
4
  from typing import Dict, List, Literal, NotRequired, Optional, TypedDict
5
 
6
 
7
  class LabResult(TypedDict, total=False):
8
- """Structured representation of a single lab value."""
9
-
10
  name: str
11
  value: str
12
  unit: NotRequired[Optional[str]]
@@ -15,21 +12,19 @@ class LabResult(TypedDict, total=False):
15
 
16
 
17
  class MICDatum(TypedDict, total=False):
18
- """Single MIC measurement for a bugdrug pair."""
19
-
20
  organism: str
21
  antibiotic: str
22
  mic_value: str
23
  mic_unit: NotRequired[Optional[str]]
24
  interpretation: NotRequired[Optional[Literal["S", "I", "R"]]]
25
- breakpoint_source: NotRequired[Optional[str]] # e.g. EUCAST v16.0
26
  year: NotRequired[Optional[int]]
27
- site: NotRequired[Optional[str]] # e.g. blood, urine
28
 
29
 
30
  class Recommendation(TypedDict, total=False):
31
  """Final clinical recommendation assembled by Agent 4."""
32
-
33
  primary_antibiotic: Optional[str]
34
  backup_antibiotic: NotRequired[Optional[str]]
35
  dose: Optional[str]
@@ -43,24 +38,19 @@ class Recommendation(TypedDict, total=False):
43
 
44
  class InfectionState(TypedDict, total=False):
45
  """
46
- Global LangGraph state for the Med-I-C pipeline.
47
 
48
- All agents read from and write back to this object.
49
- Most keys are optional to keep the schema flexible across stages.
50
  """
51
 
52
- # ------------------------------------------------------------------
53
  # Patient identity & demographics
54
- # ------------------------------------------------------------------
55
  patient_id: NotRequired[Optional[str]]
56
  age_years: NotRequired[Optional[float]]
57
  sex: NotRequired[Optional[Literal["male", "female", "other", "unknown"]]]
58
  weight_kg: NotRequired[Optional[float]]
59
  height_cm: NotRequired[Optional[float]]
60
 
61
- # ------------------------------------------------------------------
62
  # Clinical context
63
- # ------------------------------------------------------------------
64
  suspected_source: NotRequired[Optional[str]] # e.g. "community UTI"
65
  comorbidities: NotRequired[List[str]]
66
  medications: NotRequired[List[str]]
@@ -68,58 +58,36 @@ class InfectionState(TypedDict, total=False):
68
  infection_site: NotRequired[Optional[str]]
69
  country_or_region: NotRequired[Optional[str]]
70
 
71
- # ------------------------------------------------------------------
72
- # Renal function / vitals
73
- # ------------------------------------------------------------------
74
  serum_creatinine_mg_dl: NotRequired[Optional[float]]
75
  creatinine_clearance_ml_min: NotRequired[Optional[float]]
76
  vitals: NotRequired[Dict[str, str]] # flexible key/value, e.g. {"BP": "120/80"}
77
 
78
- # ------------------------------------------------------------------
79
  # Lab data & MICs
80
- # ------------------------------------------------------------------
81
- labs_raw_text: NotRequired[Optional[str]] # raw OCR / PDF text
82
  labs_parsed: NotRequired[List[LabResult]]
83
-
84
  mic_data: NotRequired[List[MICDatum]]
85
  mic_trend_summary: NotRequired[Optional[str]]
86
 
87
- # ------------------------------------------------------------------
88
- # Stage / routing metadata
89
- # ------------------------------------------------------------------
90
  stage: NotRequired[Literal["empirical", "targeted"]]
91
  route_to_vision: NotRequired[bool]
92
  route_to_trend_analyst: NotRequired[bool]
93
 
94
- # ------------------------------------------------------------------
95
  # Agent outputs
96
- # ------------------------------------------------------------------
97
- intake_notes: NotRequired[Optional[str]] # Agent 1
98
- vision_notes: NotRequired[Optional[str]] # Agent 2
99
- trend_notes: NotRequired[Optional[str]] # Agent 3
100
- pharmacology_notes: NotRequired[Optional[str]] # Agent 4
101
-
102
  recommendation: NotRequired[Optional[Recommendation]]
103
 
104
- # ------------------------------------------------------------------
105
- # RAG / context + safety
106
- # ------------------------------------------------------------------
107
  rag_context: NotRequired[Optional[str]]
108
  guideline_sources: NotRequired[List[str]]
109
  breakpoint_sources: NotRequired[List[str]]
110
  safety_warnings: NotRequired[List[str]]
111
 
112
- # ------------------------------------------------------------------
113
- # Diagnostics / debugging
114
- # ------------------------------------------------------------------
115
  errors: NotRequired[List[str]]
116
  debug_log: NotRequired[List[str]]
117
 
118
-
119
- __all__ = [
120
- "LabResult",
121
- "MICDatum",
122
- "Recommendation",
123
- "InfectionState",
124
- ]
125
-
 
1
 
 
 
2
  from typing import Dict, List, Literal, NotRequired, Optional, TypedDict
3
 
4
 
5
  class LabResult(TypedDict, total=False):
6
+ """A single lab value with optional reference range and flag."""
 
7
  name: str
8
  value: str
9
  unit: NotRequired[Optional[str]]
 
12
 
13
 
14
class MICDatum(TypedDict, total=False):
    """A single MIC measurement for one organism-antibiotic pair.

    All keys are optional (``total=False``), so partial records are valid.
    """

    organism: str
    antibiotic: str
    mic_value: str  # MIC stored as text (exact format set by the data source)
    mic_unit: NotRequired[Optional[str]]
    interpretation: NotRequired[Optional[Literal["S", "I", "R"]]]  # susceptible / intermediate / resistant
    breakpoint_source: NotRequired[Optional[str]]  # e.g. "EUCAST v16.0"
    year: NotRequired[Optional[int]]
    site: NotRequired[Optional[str]]  # e.g. "blood", "urine"
24
 
25
 
26
  class Recommendation(TypedDict, total=False):
27
  """Final clinical recommendation assembled by Agent 4."""
 
28
  primary_antibiotic: Optional[str]
29
  backup_antibiotic: NotRequired[Optional[str]]
30
  dose: Optional[str]
 
38
 
39
  class InfectionState(TypedDict, total=False):
40
  """
41
+ Shared state object passed between all agents in the pipeline.
42
 
43
+ All keys are optional so each agent only needs to populate its own outputs.
 
44
  """
45
 
 
46
  # Patient identity & demographics
 
47
  patient_id: NotRequired[Optional[str]]
48
  age_years: NotRequired[Optional[float]]
49
  sex: NotRequired[Optional[Literal["male", "female", "other", "unknown"]]]
50
  weight_kg: NotRequired[Optional[float]]
51
  height_cm: NotRequired[Optional[float]]
52
 
 
53
  # Clinical context
 
54
  suspected_source: NotRequired[Optional[str]] # e.g. "community UTI"
55
  comorbidities: NotRequired[List[str]]
56
  medications: NotRequired[List[str]]
 
58
  infection_site: NotRequired[Optional[str]]
59
  country_or_region: NotRequired[Optional[str]]
60
 
61
+ # Renal function & vitals
 
 
62
  serum_creatinine_mg_dl: NotRequired[Optional[float]]
63
  creatinine_clearance_ml_min: NotRequired[Optional[float]]
64
  vitals: NotRequired[Dict[str, str]] # flexible key/value, e.g. {"BP": "120/80"}
65
 
 
66
  # Lab data & MICs
67
+ labs_raw_text: NotRequired[Optional[str]] # raw OCR or PDF text
 
68
  labs_parsed: NotRequired[List[LabResult]]
 
69
  mic_data: NotRequired[List[MICDatum]]
70
  mic_trend_summary: NotRequired[Optional[str]]
71
 
72
+ # Routing flags set by agents
 
 
73
  stage: NotRequired[Literal["empirical", "targeted"]]
74
  route_to_vision: NotRequired[bool]
75
  route_to_trend_analyst: NotRequired[bool]
76
 
 
77
  # Agent outputs
78
+ intake_notes: NotRequired[Optional[str]] # Agent 1
79
+ vision_notes: NotRequired[Optional[str]] # Agent 2
80
+ trend_notes: NotRequired[Optional[str]] # Agent 3
81
+ pharmacology_notes: NotRequired[Optional[str]] # Agent 4
 
 
82
  recommendation: NotRequired[Optional[Recommendation]]
83
 
84
+ # RAG context & safety
 
 
85
  rag_context: NotRequired[Optional[str]]
86
  guideline_sources: NotRequired[List[str]]
87
  breakpoint_sources: NotRequired[List[str]]
88
  safety_warnings: NotRequired[List[str]]
89
 
90
+ # Diagnostics
 
 
91
  errors: NotRequired[List[str]]
92
  debug_log: NotRequired[List[str]]
93
 
 
 
 
 
 
 
 
 
src/tools/rag_tools.py CHANGED
@@ -1,6 +1,5 @@
1
  """RAG tools for querying clinical guidelines via ChromaDB."""
2
 
3
- from typing import Optional
4
  from src.db.vector_store import search_guidelines, search_mic_reference
5
 
6
 
 
1
  """RAG tools for querying clinical guidelines via ChromaDB."""
2
 
 
3
  from src.db.vector_store import search_guidelines, search_mic_reference
4
 
5
 
src/utils.py CHANGED
@@ -1,23 +1,19 @@
1
  """
2
- Utility functions for Med-I-C multi-agent system.
3
 
4
- Includes:
5
- - Creatinine Clearance (CrCl) calculator
6
  - MIC trend analysis and creep detection
7
  - Prescription card formatter
8
- - Data validation helpers
9
  """
10
 
11
- from __future__ import annotations
12
-
13
  import json
14
  import math
 
15
  from typing import Any, Dict, List, Literal, Optional, Tuple
16
 
17
 
18
- # =============================================================================
19
- # CREATININE CLEARANCE CALCULATOR
20
- # =============================================================================
21
 
22
  def calculate_crcl(
23
  age_years: float,
@@ -28,40 +24,25 @@ def calculate_crcl(
28
  height_cm: Optional[float] = None,
29
  ) -> float:
30
  """
31
- Calculate Creatinine Clearance using the Cockcroft-Gault equation.
32
-
33
- Formula:
34
- CrCl = [(140 - age) × weight × (0.85 if female)] / (72 × SCr)
35
 
36
- Args:
37
- age_years: Patient age in years
38
- weight_kg: Actual body weight in kg
39
- serum_creatinine_mg_dl: Serum creatinine in mg/dL
40
- sex: Patient sex ("male" or "female")
41
- use_ibw: If True, use Ideal Body Weight instead of actual weight
42
- height_cm: Height in cm (required if use_ibw=True)
43
 
44
- Returns:
45
- Estimated CrCl in mL/min
 
46
  """
47
  if serum_creatinine_mg_dl <= 0:
48
  raise ValueError("Serum creatinine must be positive")
49
-
50
  if age_years <= 0 or weight_kg <= 0:
51
  raise ValueError("Age and weight must be positive")
52
 
53
- # Calculate weight to use
54
  weight = weight_kg
55
  if use_ibw and height_cm:
56
- weight = calculate_ibw(height_cm, sex)
57
- # Use adjusted body weight if actual weight > IBW
58
- if weight_kg > weight * 1.3:
59
- weight = calculate_adjusted_bw(weight, weight_kg)
60
 
61
- # Cockcroft-Gault equation
62
  crcl = ((140 - age_years) * weight) / (72 * serum_creatinine_mg_dl)
63
-
64
- # Apply sex factor
65
  if sex == "female":
66
  crcl *= 0.85
67
 
@@ -70,42 +51,27 @@ def calculate_crcl(
70
 
71
  def calculate_ibw(height_cm: float, sex: Literal["male", "female"]) -> float:
72
  """
73
- Calculate Ideal Body Weight using the Devine formula.
74
-
75
- Args:
76
- height_cm: Height in centimeters
77
- sex: Patient sex
78
 
79
- Returns:
80
- Ideal body weight in kg
81
  """
82
- height_inches = height_cm / 2.54
83
- height_over_60 = max(0, height_inches - 60)
84
-
85
- if sex == "male":
86
- ibw = 50 + 2.3 * height_over_60
87
- else:
88
- ibw = 45.5 + 2.3 * height_over_60
89
-
90
- return round(ibw, 1)
91
 
92
 
93
  def calculate_adjusted_bw(ibw: float, actual_weight: float) -> float:
94
  """
95
- Calculate Adjusted Body Weight for obese patients.
96
 
97
- Formula: AdjBW = IBW + 0.4 × (Actual - IBW)
98
  """
99
  return round(ibw + 0.4 * (actual_weight - ibw), 1)
100
 
101
 
102
  def get_renal_dose_category(crcl: float) -> str:
103
- """
104
- Categorize renal function for dosing purposes.
105
-
106
- Returns:
107
- Renal function category
108
- """
109
  if crcl >= 90:
110
  return "normal"
111
  elif crcl >= 60:
@@ -118,9 +84,7 @@ def get_renal_dose_category(crcl: float) -> str:
118
  return "esrd"
119
 
120
 
121
- # =============================================================================
122
- # MIC TREND ANALYSIS
123
- # =============================================================================
124
 
125
  def calculate_mic_trend(
126
  mic_values: List[Dict[str, Any]],
@@ -128,15 +92,11 @@ def calculate_mic_trend(
128
  resistant_breakpoint: Optional[float] = None,
129
  ) -> Dict[str, Any]:
130
  """
131
- Analyze MIC trend over time and detect MIC creep.
132
 
133
- Args:
134
- mic_values: List of dicts with 'date' and 'mic_value' keys
135
- susceptible_breakpoint: S breakpoint (optional)
136
- resistant_breakpoint: R breakpoint (optional)
137
-
138
- Returns:
139
- Dict with trend analysis results
140
  """
141
  if len(mic_values) < 2:
142
  return {
@@ -145,49 +105,28 @@ def calculate_mic_trend(
145
  "alert": "Need at least 2 MIC values for trend analysis",
146
  }
147
 
148
- # Extract MIC values
149
  mics = [float(v["mic_value"]) for v in mic_values]
150
-
151
  baseline_mic = mics[0]
152
  current_mic = mics[-1]
 
153
 
154
- # Calculate fold change
155
- if baseline_mic > 0:
156
- fold_change = current_mic / baseline_mic
157
- else:
158
- fold_change = float("inf")
159
-
160
- # Calculate trend
161
  if len(mics) >= 3:
162
- # Linear regression slope
163
  n = len(mics)
164
  x_mean = (n - 1) / 2
165
  y_mean = sum(mics) / n
166
  numerator = sum((i - x_mean) * (mics[i] - y_mean) for i in range(n))
167
  denominator = sum((i - x_mean) ** 2 for i in range(n))
168
  slope = numerator / denominator if denominator != 0 else 0
169
-
170
- if slope > 0.5:
171
- trend = "increasing"
172
- elif slope < -0.5:
173
- trend = "decreasing"
174
- else:
175
- trend = "stable"
176
  else:
177
- if current_mic > baseline_mic * 1.5:
178
- trend = "increasing"
179
- elif current_mic < baseline_mic * 0.67:
180
- trend = "decreasing"
181
- else:
182
- trend = "stable"
183
-
184
- # Calculate resistance velocity (fold change per time point)
185
  velocity = fold_change ** (1 / (len(mics) - 1)) if len(mics) > 1 else 1.0
186
 
187
- # Determine risk level
188
  risk_level, alert = _assess_mic_risk(
189
  current_mic, baseline_mic, fold_change, trend,
190
- susceptible_breakpoint, resistant_breakpoint
191
  )
192
 
193
  return {
@@ -211,51 +150,39 @@ def _assess_mic_risk(
211
  r_breakpoint: Optional[float],
212
  ) -> Tuple[str, str]:
213
  """
214
- Assess risk level based on MIC trends and breakpoints.
215
 
216
- Returns:
217
- Tuple of (risk_level, alert_message)
218
  """
219
- # If we have breakpoints, use them for risk assessment
220
  if s_breakpoint is not None and r_breakpoint is not None:
221
  margin = s_breakpoint / current_mic if current_mic > 0 else float("inf")
222
 
223
  if current_mic > r_breakpoint:
224
  return "CRITICAL", f"MIC ({current_mic}) exceeds resistant breakpoint ({r_breakpoint}). Organism is RESISTANT."
225
-
226
  if current_mic > s_breakpoint:
227
  return "HIGH", f"MIC ({current_mic}) exceeds susceptible breakpoint ({s_breakpoint}). Consider alternative therapy."
228
-
229
  if margin < 2:
230
  if trend == "increasing":
231
  return "HIGH", f"MIC approaching breakpoint (margin: {margin:.1f}x) with increasing trend. High risk of resistance emergence."
232
- else:
233
- return "MODERATE", f"MIC close to breakpoint (margin: {margin:.1f}x). Monitor closely."
234
-
235
  if margin < 4:
236
  if trend == "increasing":
237
  return "MODERATE", f"MIC rising with {margin:.1f}x margin to breakpoint. Consider enhanced monitoring."
238
- else:
239
- return "LOW", "MIC stable with adequate margin to breakpoint."
240
-
241
  return "LOW", "MIC well below breakpoint with good safety margin."
242
 
243
- # Without breakpoints, use fold change and trend
244
  if fold_change >= 8:
245
  return "CRITICAL", f"MIC increased {fold_change:.1f}-fold from baseline. Urgent review needed."
246
-
247
  if fold_change >= 4:
248
  return "HIGH", f"MIC increased {fold_change:.1f}-fold from baseline. High risk of treatment failure."
249
-
250
  if fold_change >= 2:
251
  if trend == "increasing":
252
  return "MODERATE", f"MIC increased {fold_change:.1f}-fold with rising trend. Enhanced monitoring recommended."
253
- else:
254
- return "LOW", f"MIC increased {fold_change:.1f}-fold but trend is {trend}."
255
-
256
  if trend == "increasing":
257
  return "MODERATE", "MIC showing upward trend. Continue monitoring."
258
-
259
  return "LOW", "MIC stable or decreasing. Current therapy appropriate."
260
 
261
 
@@ -268,58 +195,37 @@ def detect_mic_creep(
268
  """
269
  Detect MIC creep for a specific organism-antibiotic pair.
270
 
271
- Args:
272
- organism: Pathogen name
273
- antibiotic: Antibiotic name
274
- mic_history: Historical MIC values with dates
275
- breakpoints: Dict with 'susceptible' and 'resistant' keys
276
-
277
- Returns:
278
- Comprehensive MIC creep analysis
279
  """
280
- trend_analysis = calculate_mic_trend(
281
  mic_history,
282
  susceptible_breakpoint=breakpoints.get("susceptible"),
283
  resistant_breakpoint=breakpoints.get("resistant"),
284
  )
285
 
286
- # Add organism/antibiotic context
287
- trend_analysis["organism"] = organism
288
- trend_analysis["antibiotic"] = antibiotic
289
- trend_analysis["breakpoint_susceptible"] = breakpoints.get("susceptible")
290
- trend_analysis["breakpoint_resistant"] = breakpoints.get("resistant")
291
 
292
- # Calculate time to resistance estimate
293
- if trend_analysis["trend"] == "increasing" and trend_analysis["velocity"] > 1.0:
294
- current = trend_analysis["current_mic"]
295
  s_bp = breakpoints.get("susceptible")
296
  if s_bp and current < s_bp:
297
- # Estimate doublings needed to reach breakpoint
298
  doublings_needed = math.log2(s_bp / current) if current > 0 else 0
299
- # Estimate time based on velocity
300
- if trend_analysis["velocity"] > 1.0:
301
- log_velocity = math.log(trend_analysis["velocity"]) / math.log(2)
302
- if log_velocity > 0:
303
- time_estimate = doublings_needed / log_velocity
304
- trend_analysis["estimated_readings_to_resistance"] = round(time_estimate, 1)
305
 
306
- return trend_analysis
307
 
308
 
309
- # =============================================================================
310
- # PRESCRIPTION FORMATTER
311
- # =============================================================================
312
 
313
  def format_prescription_card(recommendation: Dict[str, Any]) -> str:
314
- """
315
- Format a recommendation into a readable prescription card.
316
-
317
- Args:
318
- recommendation: Dict with recommendation details
319
-
320
- Returns:
321
- Formatted prescription card as string
322
- """
323
  lines = []
324
  lines.append("=" * 50)
325
  lines.append("ANTIBIOTIC PRESCRIPTION")
@@ -336,14 +242,12 @@ def format_prescription_card(recommendation: Dict[str, Any]) -> str:
336
  if primary.get("aware_category"):
337
  lines.append(f"WHO AWaRe: {primary.get('aware_category')}")
338
 
339
- # Dose adjustments
340
  adjustments = recommendation.get("dose_adjustments", {})
341
  if adjustments.get("renal") and adjustments["renal"] != "None needed":
342
  lines.append(f"\nRENAL ADJUSTMENT: {adjustments['renal']}")
343
  if adjustments.get("hepatic") and adjustments["hepatic"] != "None needed":
344
  lines.append(f"HEPATIC ADJUSTMENT: {adjustments['hepatic']}")
345
 
346
- # Safety alerts
347
  alerts = recommendation.get("safety_alerts", [])
348
  if alerts:
349
  lines.append("\n" + "-" * 50)
@@ -353,7 +257,6 @@ def format_prescription_card(recommendation: Dict[str, Any]) -> str:
353
  marker = {"CRITICAL": "[!!!]", "WARNING": "[!!]", "INFO": "[i]"}.get(level, "[?]")
354
  lines.append(f" {marker} {alert.get('message', '')}")
355
 
356
- # Monitoring
357
  monitoring = recommendation.get("monitoring_parameters", [])
358
  if monitoring:
359
  lines.append("\n" + "-" * 50)
@@ -361,47 +264,33 @@ def format_prescription_card(recommendation: Dict[str, Any]) -> str:
361
  for param in monitoring:
362
  lines.append(f" - {param}")
363
 
364
- # Rationale
365
  if recommendation.get("rationale"):
366
  lines.append("\n" + "-" * 50)
367
  lines.append("RATIONALE:")
368
  lines.append(f" {recommendation['rationale']}")
369
 
370
  lines.append("\n" + "=" * 50)
371
-
372
  return "\n".join(lines)
373
 
374
 
375
- # =============================================================================
376
- # JSON PARSING HELPERS
377
- # =============================================================================
378
 
379
  def safe_json_parse(text: str) -> Optional[Dict[str, Any]]:
380
  """
381
- Safely parse JSON from agent output, handling common issues.
382
 
383
- Attempts to extract JSON from text that may contain markdown code blocks
384
- or other formatting.
385
  """
386
  if not text:
387
  return None
388
 
389
- # Try direct parse first
390
  try:
391
  return json.loads(text)
392
  except json.JSONDecodeError:
393
  pass
394
 
395
- # Try to extract JSON from markdown code block
396
- import re
397
-
398
- json_patterns = [
399
- r"```json\s*\n?(.*?)\n?```", # ```json ... ```
400
- r"```\s*\n?(.*?)\n?```", # ``` ... ```
401
- r"\{[\s\S]*\}", # Raw JSON object
402
- ]
403
-
404
- for pattern in json_patterns:
405
  match = re.search(pattern, text, re.DOTALL)
406
  if match:
407
  try:
@@ -414,29 +303,15 @@ def safe_json_parse(text: str) -> Optional[Dict[str, Any]]:
414
 
415
 
416
  def validate_agent_output(output: Dict[str, Any], required_fields: List[str]) -> Tuple[bool, List[str]]:
417
- """
418
- Validate that agent output contains required fields.
419
-
420
- Args:
421
- output: Agent output dict
422
- required_fields: List of required field names
423
-
424
- Returns:
425
- Tuple of (is_valid, list_of_missing_fields)
426
- """
427
- missing = [field for field in required_fields if field not in output]
428
  return len(missing) == 0, missing
429
 
430
 
431
- # =============================================================================
432
- # DATA NORMALIZATION
433
- # =============================================================================
434
 
435
  def normalize_antibiotic_name(name: str) -> str:
436
- """
437
- Normalize antibiotic name to standard format.
438
- """
439
- # Common name mappings
440
  mappings = {
441
  "amox": "amoxicillin",
442
  "amox/clav": "amoxicillin-clavulanate",
@@ -459,18 +334,11 @@ def normalize_antibiotic_name(name: str) -> str:
459
  "cefepime": "cefepime",
460
  "maxipime": "cefepime",
461
  }
462
-
463
- normalized = name.lower().strip()
464
- return mappings.get(normalized, normalized)
465
 
466
 
467
  def normalize_organism_name(name: str) -> str:
468
- """
469
- Normalize organism name to standard format.
470
- """
471
- name = name.strip()
472
-
473
- # Common abbreviations
474
  abbreviations = {
475
  "e. coli": "Escherichia coli",
476
  "e.coli": "Escherichia coli",
@@ -485,21 +353,4 @@ def normalize_organism_name(name: str) -> str:
485
  "enterococcus": "Enterococcus species",
486
  "vre": "Enterococcus (VRE)",
487
  }
488
-
489
- lower_name = name.lower()
490
- return abbreviations.get(lower_name, name)
491
-
492
-
493
- __all__ = [
494
- "calculate_crcl",
495
- "calculate_ibw",
496
- "calculate_adjusted_bw",
497
- "get_renal_dose_category",
498
- "calculate_mic_trend",
499
- "detect_mic_creep",
500
- "format_prescription_card",
501
- "safe_json_parse",
502
- "validate_agent_output",
503
- "normalize_antibiotic_name",
504
- "normalize_organism_name",
505
- ]
 
1
  """
2
+ Utility functions for clinical calculations and data parsing.
3
 
4
+ - Creatinine Clearance (CrCl) via Cockcroft-Gault
 
5
  - MIC trend analysis and creep detection
6
  - Prescription card formatter
7
+ - JSON parsing and data normalization helpers
8
  """
9
 
 
 
10
  import json
11
  import math
12
+ import re
13
  from typing import Any, Dict, List, Literal, Optional, Tuple
14
 
15
 
16
+ # --- CrCl calculator ---
 
 
17
 
18
  def calculate_crcl(
19
  age_years: float,
 
24
  height_cm: Optional[float] = None,
25
  ) -> float:
26
  """
27
+ Cockcroft-Gault equation.
 
 
 
28
 
29
+ CrCl = [(140 - age) × weight × (0.85 if female)] / (72 × SCr)
 
 
 
 
 
 
30
 
31
+ When use_ibw=True and height is given, uses Ideal Body Weight.
32
+ For obese patients (actual > 1.3 × IBW), switches to Adjusted Body Weight.
33
+ Returns CrCl in mL/min.
34
  """
35
  if serum_creatinine_mg_dl <= 0:
36
  raise ValueError("Serum creatinine must be positive")
 
37
  if age_years <= 0 or weight_kg <= 0:
38
  raise ValueError("Age and weight must be positive")
39
 
 
40
  weight = weight_kg
41
  if use_ibw and height_cm:
42
+ ibw = calculate_ibw(height_cm, sex)
43
+ weight = calculate_adjusted_bw(ibw, weight_kg) if weight_kg > ibw * 1.3 else ibw
 
 
44
 
 
45
  crcl = ((140 - age_years) * weight) / (72 * serum_creatinine_mg_dl)
 
 
46
  if sex == "female":
47
  crcl *= 0.85
48
 
 
51
 
52
def calculate_ibw(height_cm: float, sex: Literal["male", "female"]) -> float:
    """
    Ideal Body Weight via the Devine formula.

    Base weight is 50 kg (male) or 45.5 kg (female), plus 2.3 kg per inch
    of height above 60 inches (5 feet). Rounded to 0.1 kg.
    """
    if sex == "male":
        base_kg = 50
    else:
        base_kg = 45.5
    inches_over_five_feet = max(0, height_cm / 2.54 - 60)
    return round(base_kg + 2.3 * inches_over_five_feet, 1)
 
 
 
 
 
 
62
 
63
 
64
def calculate_adjusted_bw(ibw: float, actual_weight: float) -> float:
    """
    Adjusted Body Weight for obese patients.

    AdjBW = IBW + 0.4 × (Actual - IBW), rounded to 0.1 kg.
    """
    excess = actual_weight - ibw
    adjusted = ibw + 0.4 * excess
    return round(adjusted, 1)
71
 
72
 
73
  def get_renal_dose_category(crcl: float) -> str:
74
+ """Map CrCl value to a dosing category string."""
 
 
 
 
 
75
  if crcl >= 90:
76
  return "normal"
77
  elif crcl >= 60:
 
84
  return "esrd"
85
 
86
 
87
+ # --- MIC trend analysis ---
 
 
88
 
89
  def calculate_mic_trend(
90
  mic_values: List[Dict[str, Any]],
 
92
  resistant_breakpoint: Optional[float] = None,
93
  ) -> Dict[str, Any]:
94
  """
95
+ Analyze a list of MIC readings over time.
96
 
97
+ Requires at least 2 readings. Uses linear regression slope for trend
98
+ direction when >= 3 points are available; falls back to ratio comparison
99
+ for exactly 2 points.
 
 
 
 
100
  """
101
  if len(mic_values) < 2:
102
  return {
 
105
  "alert": "Need at least 2 MIC values for trend analysis",
106
  }
107
 
 
108
  mics = [float(v["mic_value"]) for v in mic_values]
 
109
  baseline_mic = mics[0]
110
  current_mic = mics[-1]
111
+ fold_change = (current_mic / baseline_mic) if baseline_mic > 0 else float("inf")
112
 
 
 
 
 
 
 
 
113
  if len(mics) >= 3:
 
114
  n = len(mics)
115
  x_mean = (n - 1) / 2
116
  y_mean = sum(mics) / n
117
  numerator = sum((i - x_mean) * (mics[i] - y_mean) for i in range(n))
118
  denominator = sum((i - x_mean) ** 2 for i in range(n))
119
  slope = numerator / denominator if denominator != 0 else 0
120
+ trend = "increasing" if slope > 0.5 else "decreasing" if slope < -0.5 else "stable"
 
 
 
 
 
 
121
  else:
122
+ trend = "increasing" if current_mic > baseline_mic * 1.5 else "decreasing" if current_mic < baseline_mic * 0.67 else "stable"
123
+
124
+ # Fold change per time step (geometric rate of change)
 
 
 
 
 
125
  velocity = fold_change ** (1 / (len(mics) - 1)) if len(mics) > 1 else 1.0
126
 
 
127
  risk_level, alert = _assess_mic_risk(
128
  current_mic, baseline_mic, fold_change, trend,
129
+ susceptible_breakpoint, resistant_breakpoint,
130
  )
131
 
132
  return {
 
150
  r_breakpoint: Optional[float],
151
  ) -> Tuple[str, str]:
152
  """
153
+ Assign a risk level (LOW/MODERATE/HIGH/CRITICAL) based on breakpoints and fold change.
154
 
155
+ Prefers breakpoint-based assessment when breakpoints are available.
156
+ Falls back to fold-change thresholds otherwise.
157
  """
 
158
  if s_breakpoint is not None and r_breakpoint is not None:
159
  margin = s_breakpoint / current_mic if current_mic > 0 else float("inf")
160
 
161
  if current_mic > r_breakpoint:
162
  return "CRITICAL", f"MIC ({current_mic}) exceeds resistant breakpoint ({r_breakpoint}). Organism is RESISTANT."
 
163
  if current_mic > s_breakpoint:
164
  return "HIGH", f"MIC ({current_mic}) exceeds susceptible breakpoint ({s_breakpoint}). Consider alternative therapy."
 
165
  if margin < 2:
166
  if trend == "increasing":
167
  return "HIGH", f"MIC approaching breakpoint (margin: {margin:.1f}x) with increasing trend. High risk of resistance emergence."
168
+ return "MODERATE", f"MIC close to breakpoint (margin: {margin:.1f}x). Monitor closely."
 
 
169
  if margin < 4:
170
  if trend == "increasing":
171
  return "MODERATE", f"MIC rising with {margin:.1f}x margin to breakpoint. Consider enhanced monitoring."
172
+ return "LOW", "MIC stable with adequate margin to breakpoint."
 
 
173
  return "LOW", "MIC well below breakpoint with good safety margin."
174
 
175
+ # No breakpoints use fold change thresholds from EUCAST MIC creep criteria
176
  if fold_change >= 8:
177
  return "CRITICAL", f"MIC increased {fold_change:.1f}-fold from baseline. Urgent review needed."
 
178
  if fold_change >= 4:
179
  return "HIGH", f"MIC increased {fold_change:.1f}-fold from baseline. High risk of treatment failure."
 
180
  if fold_change >= 2:
181
  if trend == "increasing":
182
  return "MODERATE", f"MIC increased {fold_change:.1f}-fold with rising trend. Enhanced monitoring recommended."
183
+ return "LOW", f"MIC increased {fold_change:.1f}-fold but trend is {trend}."
 
 
184
  if trend == "increasing":
185
  return "MODERATE", "MIC showing upward trend. Continue monitoring."
 
186
  return "LOW", "MIC stable or decreasing. Current therapy appropriate."
187
 
188
 
 
195
  """
196
  Detect MIC creep for a specific organism-antibiotic pair.
197
 
198
+ Augments calculate_mic_trend with a time-to-resistance estimate
199
+ when the MIC is rising and a susceptible breakpoint is available.
 
 
 
 
 
 
200
  """
201
+ result = calculate_mic_trend(
202
  mic_history,
203
  susceptible_breakpoint=breakpoints.get("susceptible"),
204
  resistant_breakpoint=breakpoints.get("resistant"),
205
  )
206
 
207
+ result["organism"] = organism
208
+ result["antibiotic"] = antibiotic
209
+ result["breakpoint_susceptible"] = breakpoints.get("susceptible")
210
+ result["breakpoint_resistant"] = breakpoints.get("resistant")
 
211
 
212
+ # Estimate how many more time-points until MIC reaches the susceptible breakpoint
213
+ if result["trend"] == "increasing" and result["velocity"] > 1.0:
214
+ current = result["current_mic"]
215
  s_bp = breakpoints.get("susceptible")
216
  if s_bp and current < s_bp:
 
217
  doublings_needed = math.log2(s_bp / current) if current > 0 else 0
218
+ log_velocity = math.log(result["velocity"]) / math.log(2)
219
+ if log_velocity > 0:
220
+ result["estimated_readings_to_resistance"] = round(doublings_needed / log_velocity, 1)
 
 
 
221
 
222
+ return result
223
 
224
 
225
+ # --- Prescription formatter ---
 
 
226
 
227
  def format_prescription_card(recommendation: Dict[str, Any]) -> str:
228
+ """Format a recommendation dict as a plain-text prescription card."""
 
 
 
 
 
 
 
 
229
  lines = []
230
  lines.append("=" * 50)
231
  lines.append("ANTIBIOTIC PRESCRIPTION")
 
242
  if primary.get("aware_category"):
243
  lines.append(f"WHO AWaRe: {primary.get('aware_category')}")
244
 
 
245
  adjustments = recommendation.get("dose_adjustments", {})
246
  if adjustments.get("renal") and adjustments["renal"] != "None needed":
247
  lines.append(f"\nRENAL ADJUSTMENT: {adjustments['renal']}")
248
  if adjustments.get("hepatic") and adjustments["hepatic"] != "None needed":
249
  lines.append(f"HEPATIC ADJUSTMENT: {adjustments['hepatic']}")
250
 
 
251
  alerts = recommendation.get("safety_alerts", [])
252
  if alerts:
253
  lines.append("\n" + "-" * 50)
 
257
  marker = {"CRITICAL": "[!!!]", "WARNING": "[!!]", "INFO": "[i]"}.get(level, "[?]")
258
  lines.append(f" {marker} {alert.get('message', '')}")
259
 
 
260
  monitoring = recommendation.get("monitoring_parameters", [])
261
  if monitoring:
262
  lines.append("\n" + "-" * 50)
 
264
  for param in monitoring:
265
  lines.append(f" - {param}")
266
 
 
267
  if recommendation.get("rationale"):
268
  lines.append("\n" + "-" * 50)
269
  lines.append("RATIONALE:")
270
  lines.append(f" {recommendation['rationale']}")
271
 
272
  lines.append("\n" + "=" * 50)
 
273
  return "\n".join(lines)
274
 
275
 
276
+ # --- JSON parsing ---
 
 
277
 
278
  def safe_json_parse(text: str) -> Optional[Dict[str, Any]]:
279
  """
280
+ Extract and parse the first JSON object from a string.
281
 
282
+ Handles model output that may wrap JSON in markdown code fences.
283
+ Returns None if no valid JSON is found.
284
  """
285
  if not text:
286
  return None
287
 
 
288
  try:
289
  return json.loads(text)
290
  except json.JSONDecodeError:
291
  pass
292
 
293
+ for pattern in [r"```json\s*\n?(.*?)\n?```", r"```\s*\n?(.*?)\n?```", r"\{[\s\S]*\}"]:
 
 
 
 
 
 
 
 
 
294
  match = re.search(pattern, text, re.DOTALL)
295
  if match:
296
  try:
 
303
 
304
 
305
def validate_agent_output(output: Dict[str, Any], required_fields: List[str]) -> Tuple[bool, List[str]]:
    """Check that every required field is present in an agent's output.

    Returns a (is_valid, missing_fields) tuple; is_valid is True exactly
    when missing_fields is empty.
    """
    missing: List[str] = []
    for field in required_fields:
        if field not in output:
            missing.append(field)
    return not missing, missing
309
 
310
 
311
+ # --- Name normalization ---
 
 
312
 
313
  def normalize_antibiotic_name(name: str) -> str:
314
+ """Map common abbreviations and brand names to standard antibiotic names."""
 
 
 
315
  mappings = {
316
  "amox": "amoxicillin",
317
  "amox/clav": "amoxicillin-clavulanate",
 
334
  "cefepime": "cefepime",
335
  "maxipime": "cefepime",
336
  }
337
+ return mappings.get(name.lower().strip(), name.lower().strip())
 
 
338
 
339
 
340
  def normalize_organism_name(name: str) -> str:
341
+ """Map common abbreviations to full organism names."""
 
 
 
 
 
342
  abbreviations = {
343
  "e. coli": "Escherichia coli",
344
  "e.coli": "Escherichia coli",
 
353
  "enterococcus": "Enterococcus species",
354
  "vre": "Enterococcus (VRE)",
355
  }
356
+ return abbreviations.get(name.strip().lower(), name.strip())