| | """ |
| | Four-agent pipeline for the infection lifecycle workflow. |
| | |
| | Agent 1 - Intake Historian: parse patient data, calculate CrCl, identify AMR risk factors |
| | Agent 2 - Vision Specialist: extract organisms and MIC values from lab reports |
| | Agent 3 - Trend Analyst: detect MIC creep and resistance velocity |
| | Agent 4 - Clinical Pharmacologist: generate final antibiotic recommendation with safety checks |
| | """ |
| |
|
| | import json |
| | import logging |
| | from typing import Optional |
| |
|
| | from .config import get_settings |
| | from .loader import run_inference, run_inference_with_image, TextModelName |
| | from .prompts import ( |
| | INTAKE_HISTORIAN_SYSTEM, |
| | INTAKE_HISTORIAN_PROMPT, |
| | VISION_SPECIALIST_SYSTEM, |
| | VISION_SPECIALIST_PROMPT, |
| | TREND_ANALYST_SYSTEM, |
| | TREND_ANALYST_PROMPT, |
| | CLINICAL_PHARMACOLOGIST_SYSTEM, |
| | CLINICAL_PHARMACOLOGIST_PROMPT, |
| | TXGEMMA_SAFETY_PROMPT, |
| | ) |
| | from .rag import get_context_for_agent |
| | from .state import InfectionState |
| | from .utils import ( |
| | calculate_crcl, |
| | get_renal_dose_category, |
| | safe_json_parse, |
| | normalize_organism_name, |
| | normalize_antibiotic_name, |
| | ) |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
def run_intake_historian(state: InfectionState) -> InfectionState:
    """Parse patient data, calculate CrCl, identify MDR risk factors, and set the treatment stage.

    Reads the patient fields from ``state`` and writes ``route_to_vision``,
    ``creatinine_clearance_ml_min`` (when it can be calculated, or when the
    model supplies one and no calculation succeeded), ``intake_notes`` and
    ``stage``. Inference failures are recorded in ``state["errors"]`` and the
    stage falls back to ``"empirical"``.

    Returns:
        The same ``state`` object, mutated in place.
    """
    logger.info("Running Intake Historian agent...")

    # Routing depends only on what was uploaded, not on the model call, so it
    # is decided up front and is set even if inference fails below. Image-only
    # uploads count too: the vision agent accepts raw text or image bytes.
    state["route_to_vision"] = bool(state.get("labs_raw_text") or state.get("labs_image_bytes"))

    crcl = None
    crcl_inputs = ("age_years", "weight_kg", "serum_creatinine_mg_dl", "sex")
    if all(state.get(field) for field in crcl_inputs):
        try:
            crcl = calculate_crcl(
                age_years=state["age_years"],
                weight_kg=state["weight_kg"],
                serum_creatinine_mg_dl=state["serum_creatinine_mg_dl"],
                sex=state["sex"],
                use_ibw=True,
                height_cm=state.get("height_cm"),
            )
            state["creatinine_clearance_ml_min"] = crcl
            logger.info("Calculated CrCl: %s mL/min", crcl)
        except Exception as e:
            logger.warning("Could not calculate CrCl: %s", e)
            state.setdefault("errors", []).append(f"CrCl calculation error: {e}")

    patient_data = _format_patient_data(state)
    query = f"treatment {state.get('suspected_source', '')} {state.get('infection_site', '')}"
    rag_context = get_context_for_agent(
        agent_name="intake_historian",
        query=query,
        patient_context={"pathogen_type": state.get("suspected_source")},
    )

    # `or {}` / `or []` also cover keys that are present but explicitly None.
    vitals = state.get("vitals") or {}
    site_vitals_str = "\n".join(
        f"- {key.replace('_', ' ').title()}: {value}" for key, value in vitals.items()
    ) or "None provided"

    # Built in two steps: a multi-line replacement field inside a
    # single-quoted f-string only parses on Python 3.12+.
    user_prompt = INTAKE_HISTORIAN_PROMPT.format(
        patient_data=patient_data,
        medications=", ".join(state.get("medications") or []) or "None reported",
        allergies=", ".join(state.get("allergies") or []) or "No known allergies",
        infection_site=state.get("infection_site", "Unknown"),
        suspected_source=state.get("suspected_source", "Unknown"),
        site_vitals=site_vitals_str,
        rag_context=rag_context,
    )
    prompt = f"{INTAKE_HISTORIAN_SYSTEM}\n\n{user_prompt}"

    try:
        response = run_inference(prompt=prompt, model_name="medgemma_4b", max_new_tokens=1024, temperature=0.2)
        parsed = safe_json_parse(response)
        if parsed:
            state["intake_notes"] = json.dumps(parsed, indent=2)
            # The locally calculated CrCl wins; the model's value is only a fallback.
            if parsed.get("creatinine_clearance_ml_min") and crcl is None:
                state["creatinine_clearance_ml_min"] = parsed["creatinine_clearance_ml_min"]
            # `or` also covers an explicit null stage in the model output.
            state["stage"] = parsed.get("recommended_stage") or "empirical"
        else:
            state["intake_notes"] = response
            state["stage"] = "empirical"
    except Exception as e:
        logger.error("Intake Historian error: %s", e)
        state.setdefault("errors", []).append(f"Intake Historian error: {e}")
        state["intake_notes"] = f"Error: {e}"
        state["stage"] = "empirical"

    logger.info("Intake Historian complete. Stage: %s", state.get("stage"))
    return state
| |
|
| |
|
def run_vision_specialist(state: InfectionState) -> InfectionState:
    """Extract pathogen names, MIC values, and S/I/R interpretations from a lab report.

    Uses ``state["labs_image_bytes"]`` when present (vision inference),
    otherwise ``state["labs_raw_text"]``. On a successful parse it writes
    ``vision_notes``, ``mic_data``, ``labs_parsed`` and
    ``route_to_trend_analyst``, and appends any critical findings to
    ``safety_warnings``. With no lab input, an unparseable response, or an
    error, ``route_to_trend_analyst`` is set to ``False``.

    Returns:
        The same ``state`` object, mutated in place.
    """
    logger.info("Running Vision Specialist agent...")

    labs_raw = state.get("labs_raw_text", "")
    labs_image_bytes = state.get("labs_image_bytes")

    if not labs_raw and not labs_image_bytes:
        logger.info("No lab data to process, skipping Vision Specialist")
        state["vision_notes"] = "No lab data provided"
        state["route_to_trend_analyst"] = False
        return state

    # An uploaded image takes precedence over raw text.
    if labs_image_bytes:
        source_format = "image"
        language = "Auto-detected"
        report_content = "See attached image — extract all lab data visible in the image."
    else:
        source_format = "text"
        language = "English (assumed)"
        report_content = labs_raw

    rag_context = get_context_for_agent(
        agent_name="vision_specialist",
        query="culture sensitivity susceptibility interpretation",
        patient_context={},
    )

    # Built in two steps: a multi-line replacement field inside a
    # single-quoted f-string only parses on Python 3.12+.
    # rag_context is passed like the other agents do; str.format ignores
    # keyword arguments the template has no placeholder for.
    # NOTE(review): assumes VISION_SPECIALIST_PROMPT is a plain str — confirm
    # whether the template should gain a {rag_context} placeholder.
    user_prompt = VISION_SPECIALIST_PROMPT.format(
        report_content=report_content,
        source_format=source_format,
        language=language,
        rag_context=rag_context,
    )
    prompt = f"{VISION_SPECIALIST_SYSTEM}\n\n{user_prompt}"

    try:
        if labs_image_bytes:
            from io import BytesIO
            from PIL import Image as PILImage

            image = PILImage.open(BytesIO(labs_image_bytes)).convert("RGB")
            logger.info("Running vision inference on uploaded image (%s)", image.size)
            response = run_inference_with_image(
                prompt=prompt, image=image, model_name="medgemma_4b", max_new_tokens=2048, temperature=0.1
            )
        else:
            response = run_inference(prompt=prompt, model_name="medgemma_4b", max_new_tokens=2048, temperature=0.1)

        parsed = safe_json_parse(response)
        if parsed:
            state["vision_notes"] = json.dumps(parsed, indent=2)

            # `or []` covers JSON nulls; non-dict rows are skipped so that one
            # malformed entry does not discard the whole extraction.
            organisms = [o for o in (parsed.get("identified_organisms") or []) if isinstance(o, dict)]
            susceptibility = [r for r in (parsed.get("susceptibility_results") or []) if isinstance(r, dict)]

            mic_data = []
            for row in susceptibility:
                raw_mic = row.get("mic_value")
                mic_data.append(
                    {
                        "organism": normalize_organism_name(row.get("organism") or ""),
                        "antibiotic": normalize_antibiotic_name(row.get("antibiotic") or ""),
                        # A missing MIC stays empty instead of becoming the string "None".
                        "mic_value": "" if raw_mic is None else str(raw_mic),
                        "mic_unit": row.get("mic_unit", "mg/L"),
                        "interpretation": row.get("interpretation"),
                    }
                )

            state["mic_data"] = mic_data
            state["labs_parsed"] = [
                {
                    "name": org.get("organism_name", "Unknown"),
                    "value": org.get("colony_count", ""),
                    "flag": "pathogen" if org.get("significance") == "pathogen" else None,
                }
                for org in organisms
            ]
            state["route_to_trend_analyst"] = len(mic_data) > 0

            critical = parsed.get("critical_findings") or []
            if isinstance(critical, str):
                # list.extend on a str would add one warning per character.
                critical = [critical]
            if critical:
                state.setdefault("safety_warnings", []).extend(critical)
        else:
            state["vision_notes"] = response
            state["route_to_trend_analyst"] = False
    except Exception as e:
        logger.error("Vision Specialist error: %s", e)
        state.setdefault("errors", []).append(f"Vision Specialist error: {e}")
        state["vision_notes"] = f"Error: {e}"
        state["route_to_trend_analyst"] = False

    logger.info("Vision Specialist complete. MIC data points: %s", len(state.get("mic_data", [])))
    return state
| |
|
| |
|
def run_trend_analyst(state: InfectionState) -> InfectionState:
    """Analyze MIC trends per organism-antibiotic pair and flag high-risk creep.

    Runs one inference per entry in ``state["mic_data"]``. Writes
    ``trend_notes`` (JSON list with one item per entry: the parsed result, the
    raw response, or an error) and ``mic_trend_summary``. HIGH/CRITICAL
    results are appended to ``safety_warnings``. With no MIC data only
    ``trend_notes`` is set.

    Returns:
        The same ``state`` object, mutated in place.
    """
    logger.info("Running Trend Analyst agent...")

    mic_data = state.get("mic_data", [])
    if not mic_data:
        logger.info("No MIC data to analyze, skipping Trend Analyst")
        state["trend_notes"] = "No MIC data available for trend analysis"
        return state

    high_risk_levels = ("HIGH", "CRITICAL")
    trend_results = []
    high_risk_count = 0

    for mic in mic_data:
        organism = mic.get("organism", "Unknown")
        antibiotic = mic.get("antibiotic", "Unknown")

        rag_context = get_context_for_agent(
            agent_name="trend_analyst",
            query=f"breakpoint {organism} {antibiotic}",
            patient_context={
                "organism": organism,
                "antibiotic": antibiotic,
                "region": state.get("country_or_region"),
            },
        )

        # Only the current reading is available here; no dated history is kept in state.
        mic_history = [{"date": "current", "mic_value": mic.get("mic_value", "0")}]

        # Built in two steps: a multi-line replacement field inside a
        # single-quoted f-string only parses on Python 3.12+.
        user_prompt = TREND_ANALYST_PROMPT.format(
            organism=organism,
            antibiotic=antibiotic,
            mic_history=json.dumps(mic_history, indent=2),
            breakpoint_data=rag_context,
            resistance_context="Regional data not available",
        )
        prompt = f"{TREND_ANALYST_SYSTEM}\n\n{user_prompt}"

        try:
            response = run_inference(
                prompt=prompt,
                model_name="medgemma_27b",
                max_new_tokens=1024,
                temperature=0.2,
            )
            parsed = safe_json_parse(response)
            if parsed and isinstance(parsed, dict):
                trend_results.append(parsed)
                # Compare case-insensitively so "High" / "critical" are not missed.
                risk_level = str(parsed.get("risk_level") or "LOW").upper()
                if risk_level in high_risk_levels:
                    high_risk_count += 1
                    warning = (
                        f"MIC trend alert for {organism}/{antibiotic}: "
                        f"{parsed.get('recommendation', 'Review needed')}"
                    )
                    state.setdefault("safety_warnings", []).append(warning)
            else:
                # Unparseable, or parseable but not an object: keep the raw text.
                trend_results.append({"raw_response": response})
        except Exception as e:
            logger.error("Trend analysis error for %s/%s: %s", organism, antibiotic, e)
            trend_results.append({"error": str(e)})

    state["trend_notes"] = json.dumps(trend_results, indent=2)
    state["mic_trend_summary"] = (
        f"Analyzed {len(trend_results)} organism-antibiotic pairs. High-risk findings: {high_risk_count}"
    )

    logger.info("Trend Analyst complete. %s", state["mic_trend_summary"])
    return state
| |
|
| |
|
def run_clinical_pharmacologist(state: InfectionState) -> InfectionState:
    """Synthesize all agent outputs into a final antibiotic recommendation with safety checks.

    Reads ``intake_notes``, ``vision_notes`` and ``trend_notes`` plus the
    patient fields from ``state``. Writes ``pharmacology_notes`` and
    ``recommendation``, appends WARNING/CRITICAL alerts to
    ``safety_warnings``, and stores the supplementary TxGemma result in
    ``debug_log``. Inference failures are recorded in ``state["errors"]``.

    Returns:
        The same ``state`` object, mutated in place.
    """
    logger.info("Running Clinical Pharmacologist agent...")

    intake_summary = state.get("intake_notes", "No intake data")
    lab_results = state.get("vision_notes", "No lab data")
    trend_analysis = state.get("trend_notes", "No trend data")

    # The intake agent stores its notes as a JSON string, so decode it before
    # looking up the severity; a plain isinstance(..., dict) check never matches.
    severity = "Unknown"
    intake_notes = state.get("intake_notes")
    if isinstance(intake_notes, str):
        try:
            intake_notes = json.loads(intake_notes)
        except ValueError:
            # Raw model text or an "Error: ..." message rather than JSON.
            intake_notes = None
    if isinstance(intake_notes, dict):
        severity = intake_notes.get("infection_severity") or "Unknown"

    query = f"treatment {state.get('suspected_source', '')} antibiotic recommendation"
    rag_context = get_context_for_agent(
        agent_name="clinical_pharmacologist",
        query=query,
        patient_context={"proposed_antibiotic": None},
    )

    # `or {}` / `or []` also cover keys that are present but explicitly None.
    vitals = state.get("vitals") or {}
    site_vitals_str = "\n".join(
        f"- {key.replace('_', ' ').title()}: {value}" for key, value in vitals.items()
    ) or "None provided"

    # Built in two steps: a multi-line replacement field inside a
    # single-quoted f-string only parses on Python 3.12+.
    user_prompt = CLINICAL_PHARMACOLOGIST_PROMPT.format(
        intake_summary=intake_summary,
        lab_results=lab_results,
        trend_analysis=trend_analysis,
        age=state.get("age_years", "Unknown"),
        weight=state.get("weight_kg", "Unknown"),
        crcl=state.get("creatinine_clearance_ml_min", "Unknown"),
        allergies=", ".join(state.get("allergies") or []) or "No known allergies",
        current_medications=", ".join(state.get("medications") or []) or "None reported",
        infection_site=state.get("infection_site", "Unknown"),
        suspected_source=state.get("suspected_source", "Unknown"),
        severity=severity,
        site_vitals=site_vitals_str,
        rag_context=rag_context,
    )
    prompt = f"{CLINICAL_PHARMACOLOGIST_SYSTEM}\n\n{user_prompt}"

    try:
        response = run_inference(prompt=prompt, model_name="medgemma_4b", max_new_tokens=2048, temperature=0.2)
        parsed = safe_json_parse(response)
        if parsed:
            state["pharmacology_notes"] = json.dumps(parsed, indent=2)

            # Tolerate JSON nulls and wrong shapes so a sloppy field does not
            # discard the whole recommendation via the except below.
            primary = parsed.get("primary_recommendation")
            if not isinstance(primary, dict):
                primary = {}
            alt = parsed.get("alternative_recommendation")
            if not isinstance(alt, dict):
                alt = {}

            alert_messages = []
            escalated = []
            for alert in parsed.get("safety_alerts") or []:
                if isinstance(alert, dict):
                    message = alert.get("message")
                    level = str(alert.get("level") or "").upper()
                else:
                    # Plain-string alerts carry no level; keep the text only.
                    message = str(alert)
                    level = ""
                if not message:
                    continue
                alert_messages.append(message)
                if level in ("WARNING", "CRITICAL"):
                    escalated.append(message)

            recommendation = {
                "primary_antibiotic": primary.get("antibiotic"),
                "dose": primary.get("dose"),
                "route": primary.get("route"),
                "frequency": primary.get("frequency"),
                "duration": primary.get("duration"),
                "rationale": parsed.get("rationale"),
                "references": parsed.get("guideline_references", []),
                "safety_alerts": alert_messages,
            }
            if alt.get("antibiotic"):
                recommendation["backup_antibiotic"] = alt.get("antibiotic")

            state["recommendation"] = recommendation

            if escalated:
                state.setdefault("safety_warnings", []).extend(escalated)

            if primary.get("antibiotic"):
                safety_result = _run_txgemma_safety_check(
                    antibiotic=primary.get("antibiotic"),
                    dose=primary.get("dose"),
                    route=primary.get("route"),
                    duration=primary.get("duration"),
                    age=state.get("age_years"),
                    crcl=state.get("creatinine_clearance_ml_min"),
                    medications=state.get("medications") or [],
                )
                if safety_result:
                    state.setdefault("debug_log", []).append(f"TxGemma safety: {safety_result}")
        else:
            state["pharmacology_notes"] = response
            state["recommendation"] = {"rationale": response}
    except Exception as e:
        logger.error("Clinical Pharmacologist error: %s", e)
        state.setdefault("errors", []).append(f"Clinical Pharmacologist error: {e}")
        state["pharmacology_notes"] = f"Error: {e}"

    logger.info("Clinical Pharmacologist complete.")
    return state
| |
|
| |
|
def _format_patient_data(state: InfectionState) -> str:
    """Render the patient fields held in ``state`` as newline-separated text.

    Only fields with truthy values are emitted, in a fixed order: patient ID,
    demographics, weight, height, serum creatinine, CrCl (with its renal dose
    category), comorbidities, then the vitals. Returns
    ``"No patient data available"`` when nothing is set.
    """
    rendered = []

    patient_id = state.get("patient_id")
    if patient_id:
        rendered.append(f"Patient ID: {patient_id}")

    age = state.get("age_years")
    sex = state.get("sex")
    who = []
    if age:
        who.append(f"{age} years old")
    if sex:
        who.append(sex)
    if who:
        rendered.append(f"Demographics: {', '.join(who)}")

    # Simple "label: value unit" fields share one code path.
    simple_fields = (
        ("weight_kg", "Weight: {} kg"),
        ("height_cm", "Height: {} cm"),
        ("serum_creatinine_mg_dl", "Serum Creatinine: {} mg/dL"),
    )
    for key, template in simple_fields:
        value = state.get(key)
        if value:
            rendered.append(template.format(value))

    clearance = state.get("creatinine_clearance_ml_min")
    if clearance:
        dose_category = get_renal_dose_category(clearance)
        rendered.append(f"CrCl: {clearance} mL/min ({dose_category})")

    comorbidities = state.get("comorbidities")
    if comorbidities:
        rendered.append(f"Comorbidities: {', '.join(comorbidities)}")

    vitals = state.get("vitals")
    if vitals:
        rendered.append("Site-Specific Assessment:")
        rendered.extend(
            f" - {name.replace('_', ' ').title()}: {reading}" for name, reading in vitals.items()
        )

    if not rendered:
        return "No patient data available"
    return "\n".join(rendered)
| |
|
| |
|
def _run_txgemma_safety_check(
    antibiotic: str,
    dose: Optional[str],
    route: Optional[str],
    duration: Optional[str],
    age: Optional[float],
    crcl: Optional[float],
    medications: list,
) -> Optional[str]:
    """Run a supplementary TxGemma toxicology check on the proposed prescription.

    Args:
        antibiotic: Proposed antibiotic name.
        dose: Proposed dose; "Not specified" is sent when empty or None.
        route: Administration route; "Not specified" when empty or None.
        duration: Treatment duration; "Not specified" when empty or None.
        age: Patient age; "Unknown" is sent only when it is None.
        crcl: Creatinine clearance; "Unknown" is sent only when it is None.
        medications: Current medication names; "None" is sent when empty.

    Returns:
        The raw model response, or ``None`` if the check failed. The check is
        best-effort: every exception is logged as a warning and swallowed.
    """
    try:
        prompt = TXGEMMA_SAFETY_PROMPT.format(
            antibiotic=antibiotic,
            dose=dose or "Not specified",
            route=route or "Not specified",
            duration=duration or "Not specified",
            # Explicit None checks: a numeric 0 is a real value, not "Unknown".
            age=age if age is not None else "Unknown",
            crcl=crcl if crcl is not None else "Unknown",
            medications=", ".join(medications) if medications else "None",
        )
        return run_inference(prompt=prompt, model_name="txgemma_9b", max_new_tokens=256, temperature=0.1)
    except Exception as e:
        logger.warning("TxGemma safety check failed: %s", e)
        return None
| |
|
| |
|
# Registry mapping each agent's public name to its runner function; used by
# run_agent() for dispatch.
AGENTS = {
    "intake_historian": run_intake_historian,
    "vision_specialist": run_vision_specialist,
    "trend_analyst": run_trend_analyst,
    "clinical_pharmacologist": run_clinical_pharmacologist,
}
| |
|
| |
|
def run_agent(agent_name: str, state: InfectionState) -> InfectionState:
    """Look up ``agent_name`` in ``AGENTS`` and run that agent on ``state``.

    Raises:
        ValueError: If ``agent_name`` is not a registered agent.
    """
    runner = AGENTS.get(agent_name)
    if runner is None:
        raise ValueError(f"Unknown agent: {agent_name}")
    return runner(state)
| |
|