Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,468 Bytes
c0fff99 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
"""
LangGraph pipeline: linear flow through 4 diagnostic agents.
START → diagnostician → bias_detector → devil_advocate → consultant → END
"""
import logging
import threading
from agents.state import PipelineState
from agents import diagnostician, bias_detector, devil_advocate, consultant
logger = logging.getLogger(__name__)

# langgraph is optional: when it cannot be imported, build_graph() uses the
# sequential _FallbackGraph instead, so the graph symbols are stubbed out.
try:
    from langgraph.graph import END, START, StateGraph
except ModuleNotFoundError:
    StateGraph = None  # type: ignore[assignment]
    START = END = None
    _LANGGRAPH_AVAILABLE = False
else:
    _LANGGRAPH_AVAILABLE = True
def _check_error(state: PipelineState) -> str:
    """Pick the next branch after an agent node.

    Used as the router for the conditional edges in build_graph(): any truthy
    ``"error"`` value sends the run to END ("end"); otherwise the pipeline
    proceeds to the next agent ("continue").
    """
    return "end" if state.get("error") else "continue"
class _FallbackGraph:
    """Sequential stand-in for the compiled LangGraph pipeline.

    build_graph() returns this when langgraph is not installed. It implements
    the two methods this module calls on a graph, ``invoke`` and ``stream``.
    The four agents run in order and the run stops after the first agent that
    leaves a truthy ``"error"`` in the state.

    Each agent's return value is merged into the running state (keys it
    returns overwrite existing ones). So agents may return either the full
    state or only the keys they changed. Previously the state was replaced
    wholesale, which lost earlier keys whenever an agent returned a partial
    update.
    """

    @staticmethod
    def _steps():
        """Return the (node_name, agent_fn) pairs in pipeline order.

        Built at call time rather than class-definition time, so each run
        looks up the agents' ``run`` functions afresh.
        """
        return (
            ("diagnostician", diagnostician.run),
            ("bias_detector", bias_detector.run),
            ("devil_advocate", devil_advocate.run),
            ("consultant", consultant.run),
        )

    @staticmethod
    def _merge(state, update):
        """Return a new state dict with ``update`` applied; ``None`` means no change."""
        if update is None:
            return dict(state)
        return {**state, **update}

    def _run(self, initial_state):
        """Yield ``(node_name, state)`` after each agent, stopping after an error."""
        # Work on a copy so the caller's dict is never handed to the agents.
        state = dict(initial_state)
        for name, fn in self._steps():
            state = self._merge(state, fn(state))
            yield name, state
            if state.get("error"):
                return

    def invoke(self, initial_state: PipelineState) -> PipelineState:
        """Run the pipeline to completion (or first error); return the final state."""
        state = dict(initial_state)
        for _name, state in self._run(initial_state):
            pass
        return state

    def stream(self, initial_state: PipelineState, stream_mode: str = "updates"):
        """Yield ``{node_name: state_copy}`` after each agent completes.

        ``stream_mode`` is accepted for signature compatibility with the
        compiled graph but is ignored. Every event carries a copy of the
        accumulated state, so an agent that later mutates its input cannot
        alter an already-yielded event.
        """
        for name, state in self._run(initial_state):
            yield {name: dict(state)}
def build_graph():
    """Construct the diagnostic debiasing pipeline.

    Without langgraph, log a warning and return a _FallbackGraph. Otherwise
    wire the four agents into a StateGraph in order. Every agent but the last
    is followed by a conditional edge routed by _check_error: "continue"
    moves to the next agent and "end" jumps to END. The last agent always
    goes to END. Returns the compiled graph.
    """
    if not _LANGGRAPH_AVAILABLE:
        logger.warning("langgraph is not installed; falling back to a simple sequential pipeline.")
        return _FallbackGraph()

    workflow = StateGraph(PipelineState)
    stages = [
        ("diagnostician", diagnostician.run),
        ("bias_detector", bias_detector.run),
        ("devil_advocate", devil_advocate.run),
        ("consultant", consultant.run),
    ]

    for node_name, node_fn in stages:
        workflow.add_node(node_name, node_fn)

    workflow.add_edge(START, stages[0][0])
    # Pair each stage with its successor; the error check sits between them.
    for (current, _), (following, _) in zip(stages, stages[1:]):
        workflow.add_conditional_edges(
            current, _check_error, {"continue": following, "end": END}
        )
    workflow.add_edge(stages[-1][0], END)

    return workflow.compile()
# Process-wide cache for the compiled pipeline. get_graph() fills it on first
# use while holding the lock, so concurrent first calls build only one graph.
_compiled_graph_lock = threading.Lock()
_compiled_graph = None
def get_graph():
    """Return the shared compiled pipeline, building it on first call.

    Uses double-checked locking: a lock-free read on the fast path, then a
    re-check under ``_compiled_graph_lock`` so build_graph() runs at most
    once even when several threads race on the first call.
    """
    global _compiled_graph
    graph = _compiled_graph
    if graph is None:
        with _compiled_graph_lock:
            if _compiled_graph is None:
                _compiled_graph = build_graph()
            graph = _compiled_graph
    return graph
def _make_initial_state(
    image,
    doctor_diagnosis: str,
    clinical_context: str,
    modality: str | None = None,
) -> PipelineState:
    """Assemble the starting PipelineState for one pipeline run.

    The caller's inputs go under ``clinical_input``. A falsy ``modality``
    (None or "") defaults to "CXR". Every agent output slot and ``error``
    start as None, and ``current_step`` starts as "start".
    """
    clinical_input = {
        "image": image,
        "doctor_diagnosis": doctor_diagnosis,
        "clinical_context": clinical_context,
        "modality": modality if modality else "CXR",
    }
    empty_outputs = dict.fromkeys(
        (
            "diagnostician_output",
            "bias_detector_output",
            "devils_advocate_output",
            "consultant_output",
        )
    )
    return {
        "clinical_input": clinical_input,
        **empty_outputs,
        "current_step": "start",
        "error": None,
    }
def run_pipeline(
    image,
    doctor_diagnosis: str,
    clinical_context: str,
    modality: str | None = None,
) -> PipelineState:
    """Run the whole debiasing pipeline synchronously.

    Fetches the shared graph, seeds it with the initial state built from the
    arguments, and returns whatever the graph's ``invoke`` returns.
    """
    pipeline = get_graph()
    return pipeline.invoke(
        _make_initial_state(image, doctor_diagnosis, clinical_context, modality=modality)
    )
def stream_pipeline(
image,
doctor_diagnosis: str,
clinical_context: str,
modality: str | None = None,
):
"""
Stream the pipeline, yielding (node_name, state) after each agent completes.
Use this for progressive UI updates.
"""
graph = get_graph()
initial_state = _make_initial_state(image, doctor_diagnosis, clinical_context, modality=modality)
for event in graph.stream(initial_state, stream_mode="updates"):
# event is {node_name: state_update}
for node_name, state_update in event.items():
yield node_name, state_update
|