File size: 4,468 Bytes
c0fff99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""
LangGraph pipeline: linear flow through 4 diagnostic agents.

START → diagnostician → bias_detector → devil_advocate → consultant → END
"""

import logging
import threading

from agents.state import PipelineState
from agents import diagnostician, bias_detector, devil_advocate, consultant

logger = logging.getLogger(__name__)

# langgraph is an optional dependency: if it is missing we degrade gracefully
# to the sequential _FallbackGraph below instead of failing at import time.
try:
    from langgraph.graph import StateGraph, START, END

    _LANGGRAPH_AVAILABLE = True
except ModuleNotFoundError:
    StateGraph = None  # type: ignore[assignment]
    # Placeholders so the module-level names exist; build_graph() checks
    # _LANGGRAPH_AVAILABLE before ever touching them.
    START = END = None
    _LANGGRAPH_AVAILABLE = False


def _check_error(state: PipelineState) -> str:
    """Route to END if an error occurred, otherwise continue."""
    if state.get("error"):
        return "end"
    return "continue"


class _FallbackGraph:
    """Plain-Python substitute for the compiled LangGraph pipeline.

    Exposes the same ``invoke``/``stream`` surface, running the four agents
    sequentially and stopping after the first agent that sets ``error`` —
    the same short-circuit behavior as the conditional edges in build_graph().
    """

    @staticmethod
    def _stages():
        # (node_name, runner) pairs in the same order as the real graph.
        return (
            ("diagnostician", diagnostician.run),
            ("bias_detector", bias_detector.run),
            ("devil_advocate", devil_advocate.run),
            ("consultant", consultant.run),
        )

    def invoke(self, initial_state: PipelineState) -> PipelineState:
        """Run every agent in order and return the final state (blocking)."""
        state = initial_state
        for _, run_agent in self._stages():
            state = run_agent(state)
            if state.get("error"):
                break
        return state

    def stream(self, initial_state: PipelineState, stream_mode: str = "updates"):
        """Yield ``{node_name: state}`` after each agent completes; stop on error."""
        state = initial_state
        for node_name, run_agent in self._stages():
            state = run_agent(state)
            yield {node_name: dict(state)}
            if state.get("error"):
                break


def build_graph():
    """Build and compile the diagnostic debiasing pipeline.

    Returns a compiled LangGraph when the library is importable; otherwise a
    _FallbackGraph that exposes the same invoke/stream interface.
    """
    if not _LANGGRAPH_AVAILABLE:
        logger.warning("langgraph is not installed; falling back to a simple sequential pipeline.")
        return _FallbackGraph()

    workflow = StateGraph(PipelineState)

    # Register the four agent nodes.
    for node_name, runner in (
        ("diagnostician", diagnostician.run),
        ("bias_detector", bias_detector.run),
        ("devil_advocate", devil_advocate.run),
        ("consultant", consultant.run),
    ):
        workflow.add_node(node_name, runner)

    # Linear chain; after each of the first three agents, _check_error
    # short-circuits to END when the state carries an error.
    workflow.add_edge(START, "diagnostician")
    for source, successor in (
        ("diagnostician", "bias_detector"),
        ("bias_detector", "devil_advocate"),
        ("devil_advocate", "consultant"),
    ):
        workflow.add_conditional_edges(source, _check_error, {"continue": successor, "end": END})
    workflow.add_edge("consultant", END)

    return workflow.compile()


# Lazily-built singleton of the compiled pipeline; the lock lets get_graph()
# guarantee that concurrent first callers build the graph exactly once.
_compiled_graph = None
_compiled_graph_lock = threading.Lock()


def get_graph():
    """Return the shared compiled pipeline, building it on first use.

    Thread-safe via double-checked locking on _compiled_graph_lock.
    """
    global _compiled_graph
    if _compiled_graph is None:
        with _compiled_graph_lock:
            # Re-check under the lock: another thread may have completed the
            # build while we were waiting to acquire it.
            if _compiled_graph is None:
                _compiled_graph = build_graph()
    return _compiled_graph


def _make_initial_state(
    image,
    doctor_diagnosis: str,
    clinical_context: str,
    modality: str | None = None,
) -> PipelineState:
    return {
        "clinical_input": {
            "image": image,
            "doctor_diagnosis": doctor_diagnosis,
            "clinical_context": clinical_context,
            "modality": modality or "CXR",
        },
        "diagnostician_output": None,
        "bias_detector_output": None,
        "devils_advocate_output": None,
        "consultant_output": None,
        "current_step": "start",
        "error": None,
    }


def run_pipeline(
    image,
    doctor_diagnosis: str,
    clinical_context: str,
    modality: str | None = None,
) -> PipelineState:
    """Run the full debiasing pipeline (blocking).

    Builds the initial state from the clinical inputs and executes the
    compiled graph, returning the final PipelineState.
    """
    state = _make_initial_state(image, doctor_diagnosis, clinical_context, modality=modality)
    return get_graph().invoke(state)


def stream_pipeline(
    image,
    doctor_diagnosis: str,
    clinical_context: str,
    modality: str | None = None,
):
    """
    Stream the pipeline, yielding (node_name, state) after each agent completes.
    Use this for progressive UI updates.
    """
    pipeline = get_graph()
    state = _make_initial_state(image, doctor_diagnosis, clinical_context, modality=modality)

    for update in pipeline.stream(state, stream_mode="updates"):
        # Each update maps a node name to that node's emitted state.
        yield from update.items()