AMR-Guard / src /graph.py
ghitaben's picture
Enhance patient analysis form with dynamic site-specific fields and support for lab image uploads
ba2715f
"""
LangGraph orchestrator for the infection lifecycle workflow.
Stage 1 (empirical - no lab results):
Intake Historian → Clinical Pharmacologist
Stage 2 (targeted - lab results available):
Intake Historian → Vision Specialist → [Trend Analyst →] Clinical Pharmacologist
"""
import logging
from typing import Literal
from langgraph.graph import StateGraph, END
from .agents import (
run_intake_historian,
run_vision_specialist,
run_trend_analyst,
run_clinical_pharmacologist,
)
from .state import InfectionState
# Module-level logger shared by the routing functions and pipeline entry points below.
logger = logging.getLogger(__name__)
def route_after_intake(state: InfectionState) -> Literal["vision_specialist", "clinical_pharmacologist"]:
    """Choose the node that follows the Intake Historian.

    The Vision Specialist is selected only when the state is on the
    ``"targeted"`` stage AND its ``route_to_vision`` flag is truthy; any other
    state goes straight to the Clinical Pharmacologist.
    """
    wants_vision = state.get("stage") == "targeted" and state.get("route_to_vision")
    if not wants_vision:
        logger.info("Graph: routing to Clinical Pharmacologist (empirical path)")
        return "clinical_pharmacologist"
    logger.info("Graph: routing to Vision Specialist (targeted path)")
    return "vision_specialist"
def route_after_vision(state: InfectionState) -> Literal["trend_analyst", "clinical_pharmacologist"]:
    """Choose the node that follows the Vision Specialist.

    Routes to the Trend Analyst when the state's ``route_to_trend_analyst``
    flag is truthy (presumably set when MIC values were extracted — verify in
    the Vision Specialist). Otherwise it skips straight to the Clinical
    Pharmacologist.
    """
    if not state.get("route_to_trend_analyst"):
        logger.info("Graph: skipping Trend Analyst (no MIC data)")
        return "clinical_pharmacologist"
    logger.info("Graph: routing to Trend Analyst")
    return "trend_analyst"
def build_infection_graph() -> StateGraph:
    """Build and return the (uncompiled) LangGraph for the infection pipeline.

    Topology:
        intake_historian (entry) -> vision_specialist | clinical_pharmacologist
        vision_specialist        -> trend_analyst | clinical_pharmacologist
        trend_analyst            -> clinical_pharmacologist
        clinical_pharmacologist  -> END

    Returns:
        A ``StateGraph`` over ``InfectionState``. It is NOT compiled; call
        ``.compile()`` on the result before invoking it.
    """
    graph = StateGraph(InfectionState)

    # One node per agent; node names must match the routing functions' return values.
    graph.add_node("intake_historian", run_intake_historian)
    graph.add_node("vision_specialist", run_vision_specialist)
    graph.add_node("trend_analyst", run_trend_analyst)
    graph.add_node("clinical_pharmacologist", run_clinical_pharmacologist)

    graph.set_entry_point("intake_historian")

    # After intake: targeted runs with lab input may go through the Vision Specialist.
    graph.add_conditional_edges(
        "intake_historian",
        route_after_intake,
        {"vision_specialist": "vision_specialist", "clinical_pharmacologist": "clinical_pharmacologist"},
    )
    # After vision: the Trend Analyst is optional, gated by route_after_vision.
    graph.add_conditional_edges(
        "vision_specialist",
        route_after_vision,
        {"trend_analyst": "trend_analyst", "clinical_pharmacologist": "clinical_pharmacologist"},
    )
    graph.add_edge("trend_analyst", "clinical_pharmacologist")
    graph.add_edge("clinical_pharmacologist", END)
    return graph
def run_pipeline(patient_data: dict, labs_raw_text: str | None = None) -> InfectionState:
    """
    Run the full infection pipeline and return the final state.

    The targeted (Stage 2) pathway is selected whenever lab input is present:
    either non-blank ``labs_raw_text`` or a non-empty
    ``patient_data["labs_image_bytes"]``. Without either, only the empirical
    (Stage 1) pathway runs.

    Args:
        patient_data: Patient fields copied into the initial state
            (age_years, weight_kg, height_cm, sex, serum_creatinine_mg_dl,
            medications, allergies, comorbidities, infection_site,
            suspected_source, country_or_region, vitals), plus the optional
            ``labs_image_bytes`` upload.
        labs_raw_text: Optional raw lab report text. ``None``, empty, or
            whitespace-only text is treated as "no lab text".

    Returns:
        The final graph state. If compiling or running the graph raises, the
        failure is logged and the initial state is returned with a
        ``"Pipeline error: ..."`` entry appended to ``errors``.
    """
    # A blank submission (e.g. an untouched form textarea) must not push the
    # run onto the targeted path with nothing for the Vision Specialist to parse.
    if labs_raw_text is not None and not labs_raw_text.strip():
        labs_raw_text = None

    labs_image_bytes: bytes | None = patient_data.get("labs_image_bytes")
    has_lab_input = bool(labs_raw_text or labs_image_bytes)

    # `... or []` / `... or {}` (rather than a .get() default) also replaces an
    # explicit None supplied by the caller, so agents always receive a container.
    initial_state: InfectionState = {
        "age_years": patient_data.get("age_years"),
        "weight_kg": patient_data.get("weight_kg"),
        "height_cm": patient_data.get("height_cm"),
        "sex": patient_data.get("sex"),
        "serum_creatinine_mg_dl": patient_data.get("serum_creatinine_mg_dl"),
        "medications": patient_data.get("medications") or [],
        "allergies": patient_data.get("allergies") or [],
        "comorbidities": patient_data.get("comorbidities") or [],
        "infection_site": patient_data.get("infection_site"),
        "suspected_source": patient_data.get("suspected_source"),
        "country_or_region": patient_data.get("country_or_region"),
        "vitals": patient_data.get("vitals") or {},
        "stage": "targeted" if has_lab_input else "empirical",
        "errors": [],
        "safety_warnings": [],
    }
    # Lab keys are added only when the corresponding input was provided.
    if labs_raw_text:
        initial_state["labs_raw_text"] = labs_raw_text
    if labs_image_bytes:
        initial_state["labs_image_bytes"] = labs_image_bytes

    logger.info(
        "Starting pipeline (stage: %s, lab_text=%s, lab_image=%s)",
        initial_state["stage"],
        bool(labs_raw_text),
        bool(labs_image_bytes),
    )
    logger.info(
        "Patient: %sy, %s, infection: %s",
        patient_data.get("age_years"),
        patient_data.get("sex"),
        patient_data.get("infection_site"),
    )

    try:
        compiled = build_infection_graph().compile()
        logger.info("Graph compiled successfully")
        final_state = compiled.invoke(initial_state)
    except Exception as exc:
        # Top-level boundary for the whole workflow: any failure while
        # compiling or running the graph is logged with its traceback and
        # reported through `errors` instead of propagating to the caller.
        logger.exception("Pipeline execution failed: %s", exc)
        initial_state["errors"].append(f"Pipeline error: {exc}")
        return initial_state

    logger.info("Pipeline complete")
    return final_state
def run_empirical_pipeline(patient_data: dict) -> InfectionState:
    """Shorthand for run_pipeline without lab data (Stage 1).

    ``run_pipeline`` also switches to the targeted path when ``patient_data``
    carries ``labs_image_bytes``. That key is therefore dropped from a shallow
    copy here, so this shorthand always runs the empirical pathway. The
    caller's dict is not modified.
    """
    empirical_data = {key: value for key, value in patient_data.items() if key != "labs_image_bytes"}
    return run_pipeline(empirical_data)
def run_targeted_pipeline(patient_data: dict, labs_raw_text: str) -> InfectionState:
    """Convenience wrapper for a run that includes lab report text (Stage 2).

    Forwards both arguments to ``run_pipeline``, which selects the targeted
    pathway when it finds lab input, and returns its final state unchanged.
    """
    pipeline_result = run_pipeline(patient_data, labs_raw_text=labs_raw_text)
    return pipeline_result
def get_graph_mermaid() -> str:
    """Return a Mermaid diagram of the graph (for documentation and debugging).

    The diagram is rendered from the compiled graph. If building, compiling,
    or rendering raises, the failure is logged as a warning (with traceback)
    and a hand-written diagram of the same topology is returned instead. This
    function never raises.
    """
    try:
        return build_infection_graph().compile().get_graph().draw_mermaid()
    except Exception:
        # Best-effort helper: keep returning a diagram, but leave a trace of
        # why live rendering was unavailable instead of swallowing it silently.
        logger.warning("Graph: Mermaid rendering failed; returning static fallback diagram", exc_info=True)
        # Static fallback — must be kept in sync with build_infection_graph by hand.
        return """
graph TD
A[intake_historian] --> B{route_after_intake}
B -->|targeted + lab data| C[vision_specialist]
B -->|empirical| E[clinical_pharmacologist]
C --> D{route_after_vision}
D -->|has MIC data| F[trend_analyst]
D -->|no MIC data| E
F --> E
E --> G[END]
"""