| """ | |
| LangGraph orchestrator for the infection lifecycle workflow. | |
| Stage 1 (empirical - no lab results): | |
| Intake Historian → Clinical Pharmacologist | |
| Stage 2 (targeted - lab results available): | |
| Intake Historian → Vision Specialist → [Trend Analyst →] Clinical Pharmacologist | |
| """ | |
| import logging | |
| from typing import Literal | |
| from langgraph.graph import StateGraph, END | |
| from .agents import ( | |
| run_intake_historian, | |
| run_vision_specialist, | |
| run_trend_analyst, | |
| run_clinical_pharmacologist, | |
| ) | |
| from .state import InfectionState | |
# Logger named after this module (``__name__``); used by the routing functions
# and the pipeline runner below.
logger: logging.Logger = logging.getLogger(name=__name__)
def route_after_intake(
    state: InfectionState,
) -> Literal["vision_specialist", "clinical_pharmacologist"]:
    """Choose the node that follows the Intake Historian.

    The Vision Specialist is visited only when the run is in the
    ``"targeted"`` stage *and* the state carries a truthy ``route_to_vision``
    flag. Every other case goes straight to the Clinical Pharmacologist.

    Args:
        state: Current pipeline state.

    Returns:
        ``"vision_specialist"`` or ``"clinical_pharmacologist"``. These are the
        keys of the path map registered for this node in
        ``build_infection_graph``.
    """
    is_targeted = state.get("stage") == "targeted"
    if is_targeted and state.get("route_to_vision"):
        logger.info("Graph: routing to Vision Specialist (targeted path)")
        return "vision_specialist"

    if is_targeted:
        # The stage is "targeted" (run_pipeline sets this when lab text or an
        # image was supplied), but intake did not request lab parsing. Do not
        # mislabel this as the empirical path in the logs.
        logger.info(
            "Graph: targeted stage without route_to_vision; "
            "routing to Clinical Pharmacologist"
        )
    else:
        logger.info("Graph: routing to Clinical Pharmacologist (empirical path)")
    return "clinical_pharmacologist"
def route_after_vision(
    state: InfectionState,
) -> Literal["trend_analyst", "clinical_pharmacologist"]:
    """Choose the node that follows the Vision Specialist.

    A truthy ``route_to_trend_analyst`` flag in *state* sends the run through
    the Trend Analyst. The flag is expected to mean that MIC values were
    extracted from the labs. Without the flag, the run proceeds directly to
    the Clinical Pharmacologist.
    """
    wants_trend_analysis = bool(state.get("route_to_trend_analyst"))
    if not wants_trend_analysis:
        logger.info("Graph: skipping Trend Analyst (no MIC data)")
        return "clinical_pharmacologist"

    logger.info("Graph: routing to Trend Analyst")
    return "trend_analyst"
def build_infection_graph() -> StateGraph:
    """Assemble the infection-pipeline graph. The result is NOT compiled.

    Topology::

        intake_historian  --route_after_intake-->  vision_specialist | clinical_pharmacologist
        vision_specialist --route_after_vision-->  trend_analyst | clinical_pharmacologist
        trend_analyst --> clinical_pharmacologist --> END

    Returns:
        The uncompiled ``StateGraph``. Call ``.compile()`` on it before
        invoking it, as ``run_pipeline`` and ``get_graph_mermaid`` do.
    """
    graph = StateGraph(InfectionState)

    # One node per agent. The node names are also the strings returned by the
    # routing functions.
    graph.add_node("intake_historian", run_intake_historian)
    graph.add_node("vision_specialist", run_vision_specialist)
    graph.add_node("trend_analyst", run_trend_analyst)
    graph.add_node("clinical_pharmacologist", run_clinical_pharmacologist)

    graph.set_entry_point("intake_historian")

    # Empirical vs. targeted branch after intake.
    graph.add_conditional_edges(
        "intake_historian",
        route_after_intake,
        {
            "vision_specialist": "vision_specialist",
            "clinical_pharmacologist": "clinical_pharmacologist",
        },
    )
    # Optional trend analysis after lab parsing.
    graph.add_conditional_edges(
        "vision_specialist",
        route_after_vision,
        {
            "trend_analyst": "trend_analyst",
            "clinical_pharmacologist": "clinical_pharmacologist",
        },
    )
    graph.add_edge("trend_analyst", "clinical_pharmacologist")
    graph.add_edge("clinical_pharmacologist", END)
    return graph
def _build_initial_state(
    patient_data: dict,
    labs_raw_text: str | None,
    labs_image_bytes: bytes | None,
) -> InfectionState:
    """Translate caller-supplied patient data into the graph's starting state."""
    has_lab_input = bool(labs_raw_text or labs_image_bytes)
    state: InfectionState = {
        "age_years": patient_data.get("age_years"),
        "weight_kg": patient_data.get("weight_kg"),
        "height_cm": patient_data.get("height_cm"),
        "sex": patient_data.get("sex"),
        "serum_creatinine_mg_dl": patient_data.get("serum_creatinine_mg_dl"),
        # `or` (rather than a .get default) also replaces an explicit None, so
        # the collection-typed fields always hold a real list/dict.
        "medications": patient_data.get("medications") or [],
        "allergies": patient_data.get("allergies") or [],
        "comorbidities": patient_data.get("comorbidities") or [],
        "infection_site": patient_data.get("infection_site"),
        "suspected_source": patient_data.get("suspected_source"),
        "country_or_region": patient_data.get("country_or_region"),
        "vitals": patient_data.get("vitals") or {},
        "stage": "targeted" if has_lab_input else "empirical",
        "errors": [],
        "safety_warnings": [],
    }
    # Lab fields are only added to the state when they were actually supplied.
    if labs_raw_text:
        state["labs_raw_text"] = labs_raw_text
    if labs_image_bytes:
        state["labs_image_bytes"] = labs_image_bytes
    return state


def run_pipeline(patient_data: dict, labs_raw_text: str | None = None) -> InfectionState:
    """Run the full infection pipeline and return the final state.

    The state's ``stage`` is set to ``"targeted"`` (Stage 2) when lab input
    is present. Lab input means either a non-empty ``labs_raw_text`` or a
    truthy ``patient_data["labs_image_bytes"]``. Otherwise only the
    empirical (Stage 1) pathway runs.

    Args:
        patient_data: Patient fields. The keys read are ``age_years``,
            ``weight_kg``, ``height_cm``, ``sex``, ``serum_creatinine_mg_dl``,
            ``medications``, ``allergies``, ``comorbidities``,
            ``infection_site``, ``suspected_source``, ``country_or_region``,
            ``vitals`` and the optional ``labs_image_bytes``.
        labs_raw_text: Optional raw lab/culture report text.

    Returns:
        The final graph state. This function does not raise when the pipeline
        fails: if building, compiling or running the graph raises, the error
        is logged and the *initial* state is returned with a
        ``"Pipeline error: ..."`` entry appended to its ``errors`` list.
    """
    labs_image_bytes: bytes | None = patient_data.get("labs_image_bytes")
    initial_state = _build_initial_state(patient_data, labs_raw_text, labs_image_bytes)

    logger.info(
        "Starting pipeline (stage: %s, lab_text=%s, lab_image=%s)",
        initial_state["stage"],
        bool(labs_raw_text),
        bool(labs_image_bytes),
    )
    logger.info(
        "Patient: %sy, %s, infection: %s",
        patient_data.get("age_years"),
        patient_data.get("sex"),
        patient_data.get("infection_site"),
    )

    try:
        compiled = build_infection_graph().compile()
        logger.info("Graph compiled successfully")
        final_state = compiled.invoke(initial_state)
    except Exception as e:
        # Top-level boundary: callers receive a state with the error recorded
        # instead of an exception. logger.exception attaches the traceback.
        logger.exception("Pipeline execution failed: %s", e)
        initial_state["errors"].append(f"Pipeline error: {e}")
        return initial_state

    logger.info("Pipeline complete")
    return final_state
def run_empirical_pipeline(patient_data: dict) -> InfectionState:
    """Run ``run_pipeline`` without lab report text (Stage 1 when no labs).

    Note: ``patient_data["labs_image_bytes"]``, if present and truthy, is
    still picked up by ``run_pipeline``. In that case the run is marked
    ``"targeted"`` despite this wrapper's name. Remove that key first if a
    strictly empirical run is required.
    """
    return run_pipeline(patient_data)
def run_targeted_pipeline(patient_data: dict, labs_raw_text: str) -> InfectionState:
    """Run ``run_pipeline`` with raw lab report text (Stage 2).

    Convenience wrapper. An empty ``labs_raw_text`` counts as "no lab text"
    in ``run_pipeline``.
    """
    result = run_pipeline(patient_data, labs_raw_text=labs_raw_text)
    return result
# Hand-maintained copy of the topology built by build_infection_graph. It is
# used only when live rendering fails, so keep it in sync with that function.
_FALLBACK_MERMAID = """
graph TD
    A[intake_historian] --> B{route_after_intake}
    B -->|targeted + lab data| C[vision_specialist]
    B -->|empirical| E[clinical_pharmacologist]
    C --> D{route_after_vision}
    D -->|has MIC data| F[trend_analyst]
    D -->|no MIC data| E
    F --> E
    E --> G[END]
"""


def get_graph_mermaid() -> str:
    """Return a Mermaid diagram of the graph, for documentation and debugging.

    Renders the compiled graph with LangGraph's ``draw_mermaid()``. If
    building, compiling or rendering raises, the failure is logged and a
    static diagram of the same topology is returned instead.
    """
    try:
        return build_infection_graph().compile().get_graph().draw_mermaid()
    except Exception:
        # Best-effort by design: never raise from a documentation helper, but
        # record why live rendering failed instead of swallowing it silently.
        logger.warning(
            "Could not render graph via LangGraph; returning static Mermaid fallback",
            exc_info=True,
        )
        return _FALLBACK_MERMAID