File size: 5,595 Bytes
91b591f ad9e267 91b591f ad9e267 91b591f ad9e267 91b591f ad9e267 91b591f ad9e267 91b591f ad9e267 91b591f ad9e267 91b591f ad9e267 91b591f ad9e267 91b591f ad9e267 91b591f ad9e267 91b591f ad9e267 91b591f ad9e267 91b591f ad9e267 91b591f ba2715f 91b591f ba2715f 18c0556 91b591f ba2715f 91b591f ba2715f 18c0556 91b591f ad9e267 91b591f ad9e267 91b591f ad9e267 91b591f ad9e267 91b591f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 | """
LangGraph orchestrator for the infection lifecycle workflow.
Stage 1 (empirical - no lab results):
Intake Historian → Clinical Pharmacologist
Stage 2 (targeted - lab results available):
Intake Historian → Vision Specialist → [Trend Analyst →] Clinical Pharmacologist
"""
import logging
from typing import Literal
from langgraph.graph import StateGraph, END
from .agents import (
run_intake_historian,
run_vision_specialist,
run_trend_analyst,
run_clinical_pharmacologist,
)
from .state import InfectionState
logger = logging.getLogger(__name__)
def route_after_intake(state: InfectionState) -> Literal["vision_specialist", "clinical_pharmacologist"]:
    """Choose the node that runs after the Intake Historian.

    The Vision Specialist is visited only when the state's ``stage`` equals
    "targeted" AND its ``route_to_vision`` flag is truthy. Any other
    combination sends the run straight to the Clinical Pharmacologist.
    """
    # De Morgan form of the targeted-path condition; short-circuit order is
    # unchanged, so ``route_to_vision`` is only read for targeted runs.
    if state.get("stage") != "targeted" or not state.get("route_to_vision"):
        logger.info("Graph: routing to Clinical Pharmacologist (empirical path)")
        return "clinical_pharmacologist"
    logger.info("Graph: routing to Vision Specialist (targeted path)")
    return "vision_specialist"
def route_after_vision(state: InfectionState) -> Literal["trend_analyst", "clinical_pharmacologist"]:
    """Choose the node that runs after the Vision Specialist.

    A truthy ``route_to_trend_analyst`` flag in the state sends the run through
    the Trend Analyst. Otherwise the Trend Analyst is skipped and the run goes
    directly to the Clinical Pharmacologist. (The flag is presumably set when
    MIC values were extracted — confirm against the Vision Specialist.)
    """
    if not state.get("route_to_trend_analyst"):
        logger.info("Graph: skipping Trend Analyst (no MIC data)")
        return "clinical_pharmacologist"
    logger.info("Graph: routing to Trend Analyst")
    return "trend_analyst"
def build_infection_graph() -> StateGraph:
    """Assemble the LangGraph ``StateGraph`` for the infection pipeline.

    The returned graph is NOT compiled. Callers must call ``.compile()`` on it
    before invoking it, as ``run_pipeline`` and ``get_graph_mermaid`` do.

    Topology:
      * intake_historian (entry) -> vision_specialist | clinical_pharmacologist,
        chosen by ``route_after_intake``.
      * vision_specialist -> trend_analyst | clinical_pharmacologist,
        chosen by ``route_after_vision``.
      * trend_analyst -> clinical_pharmacologist (unconditional).
      * clinical_pharmacologist -> END.

    Returns:
        The uncompiled ``StateGraph`` over ``InfectionState``.
    """
    graph = StateGraph(InfectionState)

    graph.add_node("intake_historian", run_intake_historian)
    graph.add_node("vision_specialist", run_vision_specialist)
    graph.add_node("trend_analyst", run_trend_analyst)
    graph.add_node("clinical_pharmacologist", run_clinical_pharmacologist)

    graph.set_entry_point("intake_historian")

    # The router return values are node names, so each path map is an identity map.
    graph.add_conditional_edges(
        "intake_historian",
        route_after_intake,
        {"vision_specialist": "vision_specialist", "clinical_pharmacologist": "clinical_pharmacologist"},
    )
    graph.add_conditional_edges(
        "vision_specialist",
        route_after_vision,
        {"trend_analyst": "trend_analyst", "clinical_pharmacologist": "clinical_pharmacologist"},
    )

    graph.add_edge("trend_analyst", "clinical_pharmacologist")
    graph.add_edge("clinical_pharmacologist", END)
    return graph
def _build_initial_state(
    patient_data: dict,
    labs_raw_text: str | None,
    labs_image_bytes: bytes | None,
) -> InfectionState:
    """Map caller-supplied patient data onto the graph's initial state dict."""
    has_lab_input = bool(labs_raw_text or labs_image_bytes)
    state: InfectionState = {
        "age_years": patient_data.get("age_years"),
        "weight_kg": patient_data.get("weight_kg"),
        "height_cm": patient_data.get("height_cm"),
        "sex": patient_data.get("sex"),
        "serum_creatinine_mg_dl": patient_data.get("serum_creatinine_mg_dl"),
        # `or` (instead of a .get default) also replaces an explicit None
        # value, so graph nodes always receive a real list/dict here.
        "medications": patient_data.get("medications") or [],
        "allergies": patient_data.get("allergies") or [],
        "comorbidities": patient_data.get("comorbidities") or [],
        "infection_site": patient_data.get("infection_site"),
        "suspected_source": patient_data.get("suspected_source"),
        "country_or_region": patient_data.get("country_or_region"),
        "vitals": patient_data.get("vitals") or {},
        "stage": "targeted" if has_lab_input else "empirical",
        "errors": [],
        "safety_warnings": [],
    }
    # Lab inputs are only added to the state when they were actually provided.
    if labs_raw_text:
        state["labs_raw_text"] = labs_raw_text
    if labs_image_bytes:
        state["labs_image_bytes"] = labs_image_bytes
    return state


def run_pipeline(patient_data: dict, labs_raw_text: str | None = None) -> InfectionState:
    """
    Run the full infection pipeline and return the final state.

    The targeted (Stage 2) pathway is selected when lab input is present:
    either non-blank ``labs_raw_text``, or ``patient_data["labs_image_bytes"]``.
    Without either, only the empirical (Stage 1) pathway runs.

    Args:
        patient_data: Patient fields (age_years, weight_kg, height_cm, sex,
            serum_creatinine_mg_dl, medications, allergies, comorbidities,
            infection_site, suspected_source, country_or_region, vitals and,
            optionally, labs_image_bytes). Missing scalar fields become None.
            Missing or None collection fields become an empty list/dict.
        labs_raw_text: Optional raw lab report text. Empty or whitespace-only
            text is treated as absent.

    Returns:
        The final graph state on success. If building, compiling or running
        the graph raises, the error is logged and the initial state is
        returned with a "Pipeline error: ..." entry appended to ``errors``.
    """
    labs_image_bytes: bytes | None = patient_data.get("labs_image_bytes")
    # Blank text carries no lab data and must not force the targeted path.
    if labs_raw_text is not None and not labs_raw_text.strip():
        labs_raw_text = None

    initial_state = _build_initial_state(patient_data, labs_raw_text, labs_image_bytes)

    logger.info(
        "Starting pipeline (stage: %s, lab_text=%s, lab_image=%s)",
        initial_state["stage"],
        bool(labs_raw_text),
        bool(labs_image_bytes),
    )
    # NOTE(review): this logs patient attributes at INFO level; confirm that is
    # acceptable under the deployment's data-handling policy.
    logger.info(
        "Patient: %sy, %s, infection: %s",
        patient_data.get("age_years"),
        patient_data.get("sex"),
        patient_data.get("infection_site"),
    )

    try:
        compiled = build_infection_graph().compile()
        logger.info("Graph compiled successfully")
        final_state = compiled.invoke(initial_state)
    except Exception as e:
        # Top-level boundary: report the failure through the state's `errors`
        # list instead of raising to the caller.
        logger.exception("Pipeline execution failed: %s", e)
        initial_state["errors"].append(f"Pipeline error: {e}")
        return initial_state

    logger.info("Pipeline complete")
    return final_state
def run_empirical_pipeline(patient_data: dict) -> InfectionState:
    """Call ``run_pipeline`` without any lab text (the Stage 1 shorthand).

    Note: ``run_pipeline`` still inspects ``patient_data["labs_image_bytes"]``.
    If that is set, the targeted stage is selected even through this helper.
    """
    return run_pipeline(patient_data, labs_raw_text=None)
def run_targeted_pipeline(patient_data: dict, labs_raw_text: str) -> InfectionState:
    """Call ``run_pipeline`` with lab text supplied (the Stage 2 shorthand).

    The lab text is forwarded unchanged; ``run_pipeline`` decides the stage.
    """
    return run_pipeline(patient_data, labs_raw_text)
def get_graph_mermaid() -> str:
    """Return a Mermaid diagram of the infection graph (for docs and debugging).

    The diagram is normally rendered from the compiled graph. If building,
    compiling or drawing raises, a hand-written diagram of the same topology
    is returned instead. The failure is logged with its traceback, so real
    bugs in graph construction are not hidden.

    NOTE: the fallback is maintained by hand and must be kept in sync with
    ``build_infection_graph``.
    """
    try:
        return build_infection_graph().compile().get_graph().draw_mermaid()
    except Exception:
        # Best-effort helper: never raise to the caller, but leave a trace.
        logger.warning(
            "Could not render graph via LangGraph; returning static Mermaid fallback",
            exc_info=True,
        )
        return """
graph TD
A[intake_historian] --> B{route_after_intake}
B -->|targeted + lab data| C[vision_specialist]
B -->|empirical| E[clinical_pharmacologist]
C --> D{route_after_vision}
D -->|has MIC data| F[trend_analyst]
D -->|no MIC data| E
F --> E
E --> G[END]
"""
|