File size: 5,595 Bytes
91b591f
ad9e267
91b591f
ad9e267
 
91b591f
ad9e267
 
91b591f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad9e267
 
 
91b591f
ad9e267
 
91b591f
 
 
ad9e267
 
 
91b591f
ad9e267
 
91b591f
 
 
ad9e267
91b591f
 
ad9e267
 
 
 
91b591f
 
 
 
 
 
ad9e267
91b591f
 
 
 
ad9e267
91b591f
 
 
 
 
 
 
 
ad9e267
91b591f
ad9e267
91b591f
ad9e267
 
91b591f
ba2715f
 
 
91b591f
 
 
 
 
 
 
 
 
 
 
 
 
ba2715f
18c0556
 
91b591f
 
 
 
ba2715f
 
91b591f
ba2715f
18c0556
 
 
 
 
 
 
 
 
 
 
 
91b591f
 
 
ad9e267
 
91b591f
 
 
ad9e267
91b591f
 
 
 
ad9e267
91b591f
ad9e267
91b591f
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""
LangGraph orchestrator for the infection lifecycle workflow.

Stage 1 (empirical - no lab results):
    Intake Historian → Clinical Pharmacologist

Stage 2 (targeted - lab results available):
    Intake Historian → Vision Specialist → [Trend Analyst →] Clinical Pharmacologist
"""

import logging
from typing import Literal

from langgraph.graph import StateGraph, END

from .agents import (
    run_intake_historian,
    run_vision_specialist,
    run_trend_analyst,
    run_clinical_pharmacologist,
)
from .state import InfectionState

logger = logging.getLogger(__name__)


def route_after_intake(state: InfectionState) -> Literal["vision_specialist", "clinical_pharmacologist"]:
    """Choose the node that follows the Intake Historian.

    Returns "vision_specialist" only when the state's "stage" is "targeted"
    and its "route_to_vision" flag is truthy. Every other combination goes
    straight to "clinical_pharmacologist".
    """
    targeted = state.get("stage") == "targeted"
    if not (targeted and state.get("route_to_vision")):
        # Also reached when the stage is targeted but the vision flag is unset.
        logger.info("Graph: routing to Clinical Pharmacologist (empirical path)")
        return "clinical_pharmacologist"

    logger.info("Graph: routing to Vision Specialist (targeted path)")
    return "vision_specialist"


def route_after_vision(state: InfectionState) -> Literal["trend_analyst", "clinical_pharmacologist"]:
    """Choose the node that follows the Vision Specialist.

    Returns "trend_analyst" when the state's "route_to_trend_analyst" flag is
    truthy, otherwise "clinical_pharmacologist". The log messages indicate the
    flag reflects whether MIC data is available.
    """
    if not state.get("route_to_trend_analyst"):
        logger.info("Graph: skipping Trend Analyst (no MIC data)")
        return "clinical_pharmacologist"

    logger.info("Graph: routing to Trend Analyst")
    return "trend_analyst"


def build_infection_graph() -> StateGraph:
    """Assemble the infection pipeline as a StateGraph.

    The graph is returned uncompiled; callers are expected to call
    ``.compile()`` on the result before invoking it.

    Topology:
        intake_historian  -> vision_specialist | clinical_pharmacologist
        vision_specialist -> trend_analyst | clinical_pharmacologist
        trend_analyst     -> clinical_pharmacologist
        clinical_pharmacologist -> END
    """
    workflow = StateGraph(InfectionState)

    # Register every agent under the name the routers return.
    agent_nodes = (
        ("intake_historian", run_intake_historian),
        ("vision_specialist", run_vision_specialist),
        ("trend_analyst", run_trend_analyst),
        ("clinical_pharmacologist", run_clinical_pharmacologist),
    )
    for node_name, node_fn in agent_nodes:
        workflow.add_node(node_name, node_fn)

    workflow.set_entry_point("intake_historian")

    # Each branch point maps the router's return value to the node of the
    # same name, so the path map is an identity mapping over its targets.
    branch_points = (
        ("intake_historian", route_after_intake, ("vision_specialist", "clinical_pharmacologist")),
        ("vision_specialist", route_after_vision, ("trend_analyst", "clinical_pharmacologist")),
    )
    for source, router, targets in branch_points:
        workflow.add_conditional_edges(source, router, {target: target for target in targets})

    workflow.add_edge("trend_analyst", "clinical_pharmacologist")
    workflow.add_edge("clinical_pharmacologist", END)

    return workflow


def run_pipeline(patient_data: dict, labs_raw_text: str | None = None) -> InfectionState:
    """
    Run the full infection pipeline and return the final state.

    The run is marked "targeted" (Stage 2) when lab input is present, i.e.
    when labs_raw_text is passed or patient_data carries a truthy
    "labs_image_bytes" entry. Otherwise it is marked "empirical" (Stage 1).

    Args:
        patient_data: Patient fields, read with ``.get()``. Missing scalar
            fields become None. Missing *or None* collection fields
            (medications, allergies, comorbidities, vitals) become empty
            containers.
        labs_raw_text: Optional lab report text; stored in the state only
            when truthy.

    Returns:
        The state returned by the compiled graph. If building, compiling or
        invoking the graph raises, the exception is logged with its traceback
        and the initial state is returned with a "Pipeline error: ..." entry
        appended to its "errors" list.
    """
    labs_image_bytes: bytes | None = patient_data.get("labs_image_bytes")
    has_lab_input = bool(labs_raw_text or labs_image_bytes)

    initial_state: InfectionState = {
        "age_years": patient_data.get("age_years"),
        "weight_kg": patient_data.get("weight_kg"),
        "height_cm": patient_data.get("height_cm"),
        "sex": patient_data.get("sex"),
        "serum_creatinine_mg_dl": patient_data.get("serum_creatinine_mg_dl"),
        # `.get(key, default)` keeps an explicit None (e.g. from a JSON null),
        # so fall back with `or` to guarantee a real container in the state.
        "medications": patient_data.get("medications") or [],
        "allergies": patient_data.get("allergies") or [],
        "comorbidities": patient_data.get("comorbidities") or [],
        "infection_site": patient_data.get("infection_site"),
        "suspected_source": patient_data.get("suspected_source"),
        "country_or_region": patient_data.get("country_or_region"),
        "vitals": patient_data.get("vitals") or {},
        "stage": "targeted" if has_lab_input else "empirical",
        "errors": [],
        "safety_warnings": [],
    }

    # Lab inputs are only added as keys when they were actually supplied.
    if labs_raw_text:
        initial_state["labs_raw_text"] = labs_raw_text
    if labs_image_bytes:
        initial_state["labs_image_bytes"] = labs_image_bytes

    logger.info(
        "Starting pipeline (stage: %s, lab_text=%s, lab_image=%s)",
        initial_state["stage"],
        bool(labs_raw_text),
        bool(labs_image_bytes),
    )
    logger.info(
        "Patient: %sy, %s, infection: %s",
        patient_data.get("age_years"),
        patient_data.get("sex"),
        patient_data.get("infection_site"),
    )

    try:
        compiled = build_infection_graph().compile()
        logger.info("Graph compiled successfully")
        final_state = compiled.invoke(initial_state)
        logger.info("Pipeline complete")
        return final_state
    except Exception as e:
        # Top-level boundary: report the failure through the state's error
        # list rather than raising to the caller.
        logger.exception("Pipeline execution failed: %s", e)
        initial_state["errors"].append(f"Pipeline error: {e}")
        return initial_state


def run_empirical_pipeline(patient_data: dict) -> InfectionState:
    """Run the pipeline without lab report text (Stage 1 shorthand).

    Delegates to run_pipeline with no labs_raw_text. Note that run_pipeline
    still selects the targeted stage if patient_data itself contains a truthy
    "labs_image_bytes" entry.
    """
    return run_pipeline(patient_data, labs_raw_text=None)


def run_targeted_pipeline(patient_data: dict, labs_raw_text: str) -> InfectionState:
    """Run the pipeline with lab report text (Stage 2 shorthand).

    Delegates to run_pipeline, forwarding labs_raw_text unchanged.
    """
    return run_pipeline(patient_data, labs_raw_text)


def get_graph_mermaid() -> str:
    """Return a Mermaid diagram of the graph (for documentation and debugging).

    The diagram is rendered from the compiled graph. If building, compiling
    or rendering raises, the failure is logged and a hand-written fallback
    diagram is returned instead, so callers always receive a string.
    """
    try:
        return build_infection_graph().compile().get_graph().draw_mermaid()
    except Exception:
        # Best-effort helper: keep returning the static diagram, but record
        # why, since the fallback can drift from the real graph unnoticed.
        logger.warning(
            "Could not render Mermaid diagram from the compiled graph; using static fallback",
            exc_info=True,
        )
        return """
graph TD
    A[intake_historian] --> B{route_after_intake}
    B -->|targeted + lab data| C[vision_specialist]
    B -->|empirical| E[clinical_pharmacologist]
    C --> D{route_after_vision}
    D -->|has MIC data| F[trend_analyst]
    D -->|no MIC data| E
    F --> E
    E --> G[END]
"""