yipengsun's picture
Initial commit: Diagnostic Devil's Advocate project
c0fff99
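"""
LangGraph pipeline: linear flow through 4 diagnostic agents.
START → diagnostician → bias_detector → devil_advocate → consultant → END
Any of the first three agents that sets ``error`` in the state ends the run
early. If langgraph is not installed, an equivalent sequential fallback runner
is used instead.
"""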
import logging
import threading
from agents.state import PipelineState
from agents import diagnostician, bias_detector, devil_advocate, consultant
logger = logging.getLogger(__name__)
try:
from langgraph.graph import StateGraph, START, END
_LANGGRAPH_AVAILABLE = True
except ModuleNotFoundError:
StateGraph = None # type: ignore[assignment]
START = END = None
_LANGGRAPH_AVAILABLE = False
def _check_error(state: PipelineState) -> str:
"""Route to END if an error occurred, otherwise continue."""
if state.get("error"):
return "end"
return "continue"
class _FallbackGraph:
    """Sequential stand-in for the compiled LangGraph pipeline.

    ``build_graph`` returns this when ``langgraph`` is not installed. It runs
    the same four agents in the same order as the LangGraph pipeline. It stops
    as soon as the running state carries a truthy ``"error"``, mirroring the
    ``_check_error`` routing of the real graph.

    Each agent's return value is merged into the running state, with the
    agent's keys overriding existing ones. This stays correct whether an agent
    returns the full state or only the keys it changed. Replacing the state
    wholesale would drop ``clinical_input`` and earlier outputs in the latter
    case.
    """

    @staticmethod
    def _steps():
        """Return the ordered ``(node_name, agent_fn)`` pairs; keep in sync with build_graph."""
        # Resolved per call rather than at class definition, so the current
        # ``run`` attribute of each agent module is always the one used.
        return (
            ("diagnostician", diagnostician.run),
            ("bias_detector", bias_detector.run),
            ("devil_advocate", devil_advocate.run),
            ("consultant", consultant.run),
        )

    def _run_steps(self, initial_state: PipelineState):
        """Execute the agents in order, yielding ``(name, update, state)`` after each.

        ``update`` is what the agent returned (``{}`` if it returned ``None``).
        ``state`` is the running state with ``update`` merged in. Iteration
        stops after the first step whose merged state has a truthy ``"error"``.
        """
        # Work on a copy so the caller's dict is never handed to an agent directly.
        state = dict(initial_state)
        for name, fn in self._steps():
            update = fn(state) or {}
            state = {**state, **update}
            yield name, update, state
            if state.get("error"):
                break

    def invoke(self, initial_state: PipelineState) -> PipelineState:
        """Run the agents in order and return the final merged state.

        If an agent sets ``"error"``, the remaining agents are skipped. The
        returned state then contains that error together with every output
        produced so far.
        """
        final_state = dict(initial_state)
        for _name, _update, final_state in self._run_steps(initial_state):
            pass
        return final_state

    def stream(self, initial_state: PipelineState, stream_mode: str = "updates"):
        """Yield ``{node_name: update}`` after each agent finishes.

        ``update`` is exactly what the agent returned, matching LangGraph's
        ``stream_mode="updates"`` events. ``stream_mode`` is accepted only for
        signature compatibility; other modes are not emulated.
        """
        for name, update, _state in self._run_steps(initial_state):
            yield {name: dict(update)}
def build_graph():
    """Build and compile the four-agent diagnostic debiasing pipeline.

    When ``langgraph`` is importable, returns a compiled ``StateGraph`` with the
    flow START -> diagnostician -> bias_detector -> devil_advocate ->
    consultant -> END. Each of the first three nodes routes to END instead of
    its successor when ``_check_error`` returns ``"end"``.

    Otherwise logs a warning and returns a ``_FallbackGraph`` that runs the
    same agents sequentially.
    """
    if not _LANGGRAPH_AVAILABLE:
        logger.warning("langgraph is not installed; falling back to a simple sequential pipeline.")
        return _FallbackGraph()

    # Ordered (node name, agent entry point) table driving both nodes and edges.
    pipeline = (
        ("diagnostician", diagnostician.run),
        ("bias_detector", bias_detector.run),
        ("devil_advocate", devil_advocate.run),
        ("consultant", consultant.run),
    )

    builder = StateGraph(PipelineState)
    for node_name, node_fn in pipeline:
        builder.add_node(node_name, node_fn)

    builder.add_edge(START, pipeline[0][0])
    # Every node except the last gets an error gate before its successor.
    for (current, _), (following, _) in zip(pipeline, pipeline[1:]):
        builder.add_conditional_edges(current, _check_error, {"continue": following, "end": END})
    builder.add_edge(pipeline[-1][0], END)

    return builder.compile()
# Lazily built compiled graph, shared module-wide by every caller of get_graph().
_compiled_graph = None
_compiled_graph_lock = threading.Lock()


def get_graph():
    """Return the shared compiled pipeline graph, building it on first use.

    The unlocked check keeps the already-built path lock-free. The re-check
    under ``_compiled_graph_lock`` prevents concurrent first callers from each
    storing their own graph. If ``build_graph`` raises, nothing is cached and
    the next call tries again.
    """
    global _compiled_graph
    graph = _compiled_graph
    if graph is None:
        with _compiled_graph_lock:
            if _compiled_graph is None:
                _compiled_graph = build_graph()
            graph = _compiled_graph
    return graph
def _make_initial_state(
image,
doctor_diagnosis: str,
clinical_context: str,
modality: str | None = None,
) -> PipelineState:
return {
"clinical_input": {
"image": image,
"doctor_diagnosis": doctor_diagnosis,
"clinical_context": clinical_context,
"modality": modality or "CXR",
},
"diagnostician_output": None,
"bias_detector_output": None,
"devils_advocate_output": None,
"consultant_output": None,
"current_step": "start",
"error": None,
}
def run_pipeline(
    image,
    doctor_diagnosis: str,
    clinical_context: str,
    modality: str | None = None,
) -> PipelineState:
    """Run the full debiasing pipeline to completion and return the final state.

    Blocks until every agent has run, or until an agent sets ``error`` and the
    pipeline stops early. ``modality`` defaults to ``"CXR"`` when falsy.
    """
    compiled = get_graph()
    start_state = _make_initial_state(
        image,
        doctor_diagnosis,
        clinical_context,
        modality=modality,
    )
    final_state = compiled.invoke(start_state)
    return final_state
def stream_pipeline(
    image,
    doctor_diagnosis: str,
    clinical_context: str,
    modality: str | None = None,
):
    """
    Stream the pipeline, yielding ``(node_name, state_update)`` after each agent completes.

    ``state_update`` is whatever that agent node returned. The graph is
    streamed with ``stream_mode="updates"``, so this is a per-node update and
    not necessarily the full accumulated pipeline state. Use this for
    progressive UI updates; use ``run_pipeline`` when the final merged state is
    needed. Iteration stops early if an agent sets ``error``.
    """
    graph = get_graph()
    initial_state = _make_initial_state(image, doctor_diagnosis, clinical_context, modality=modality)
    for event in graph.stream(initial_state, stream_mode="updates"):
        # Each event maps node_name -> that node's update; forward every pair.
        yield from event.items()