Spaces: Running on Zero
| """ | |
| LangGraph pipeline: linear flow through 4 diagnostic agents. | |
| START → diagnostician → bias_detector → devil_advocate → consultant → END | |
| """ | |
| import logging | |
| import threading | |
| from agents.state import PipelineState | |
| from agents import diagnostician, bias_detector, devil_advocate, consultant | |
| logger = logging.getLogger(__name__) | |
# Optional dependency: langgraph provides the real compiled graph.  When it
# is missing we remember that here and build_graph() falls back to a plain
# sequential pipeline instead (see _FallbackGraph below).
try:
    from langgraph.graph import StateGraph, START, END
    _LANGGRAPH_AVAILABLE = True
except ModuleNotFoundError:
    # Placeholders so module-level names exist even without langgraph.
    StateGraph = None  # type: ignore[assignment]
    START = END = None
    _LANGGRAPH_AVAILABLE = False
def _check_error(state: PipelineState) -> str:
    """Conditional-edge router: stop the pipeline once an error is recorded.

    Returns "end" when the state carries a truthy "error" value, otherwise
    "continue" so the graph proceeds to the next agent.
    """
    return "end" if state.get("error") else "continue"
class _FallbackGraph:
    """Minimal stand-in for a compiled langgraph graph.

    Runs the four agents strictly in order and stops early as soon as any
    agent records an ``error`` in the state.  Implements only the subset of
    the compiled-graph API this module uses (``invoke`` and ``stream``).
    """

    @staticmethod
    def _steps():
        # Generator so agent callables are looked up at call time, exactly
        # like the original inline tuples (keeps monkeypatching working).
        yield "diagnostician", diagnostician.run
        yield "bias_detector", bias_detector.run
        yield "devil_advocate", devil_advocate.run
        yield "consultant", consultant.run

    def invoke(self, initial_state: PipelineState) -> PipelineState:
        """Run all agents sequentially; return the final state."""
        state = initial_state
        for _name, agent in self._steps():
            state = agent(state)
            if state.get("error"):
                break
        return state

    def stream(self, initial_state: PipelineState, stream_mode: str = "updates"):
        """Yield ``{node_name: state}`` after each agent, stopping on error.

        ``stream_mode`` is accepted for signature compatibility with compiled
        langgraph graphs but is otherwise ignored here.
        """
        state = initial_state
        for name, agent in self._steps():
            state = agent(state)
            yield {name: dict(state)}
            if state.get("error"):
                break
def build_graph():
    """Build and compile the diagnostic debiasing pipeline.

    Wires the four agents into a linear StateGraph where each of the first
    three nodes can short-circuit to END if it recorded an error.  When
    langgraph is unavailable, returns a sequential _FallbackGraph instead.
    """
    if not _LANGGRAPH_AVAILABLE:
        logger.warning("langgraph is not installed; falling back to a simple sequential pipeline.")
        return _FallbackGraph()

    graph = StateGraph(PipelineState)

    # Register one node per agent.
    for node_name, agent in (
        ("diagnostician", diagnostician.run),
        ("bias_detector", bias_detector.run),
        ("devil_advocate", devil_advocate.run),
        ("consultant", consultant.run),
    ):
        graph.add_node(node_name, agent)

    # Linear flow; intermediate nodes route to END early on error.
    graph.add_edge(START, "diagnostician")
    for src, dst in (
        ("diagnostician", "bias_detector"),
        ("bias_detector", "devil_advocate"),
        ("devil_advocate", "consultant"),
    ):
        graph.add_conditional_edges(src, _check_error, {"continue": dst, "end": END})
    graph.add_edge("consultant", END)

    return graph.compile()
# Lazily built singleton of the compiled graph; the lock ensures concurrent
# first callers build it exactly once.
_compiled_graph = None
_compiled_graph_lock = threading.Lock()


def get_graph():
    """Return the compiled pipeline graph, building it on first use."""
    global _compiled_graph
    # Double-checked locking: lock-free fast path once the graph exists.
    if _compiled_graph is None:
        with _compiled_graph_lock:
            # Re-check inside the lock: another thread may have won the race.
            if _compiled_graph is None:
                _compiled_graph = build_graph()
    return _compiled_graph
def _make_initial_state(
    image,
    doctor_diagnosis: str,
    clinical_context: str,
    modality: str | None = None,
) -> PipelineState:
    """Assemble the pipeline's starting state.

    All agent output slots begin as None; ``modality`` defaults to "CXR"
    when not supplied (or falsy).
    """
    clinical_input = {
        "image": image,
        "doctor_diagnosis": doctor_diagnosis,
        "clinical_context": clinical_context,
        "modality": modality or "CXR",
    }
    state: PipelineState = {
        "clinical_input": clinical_input,
        "diagnostician_output": None,
        "bias_detector_output": None,
        "devils_advocate_output": None,
        "consultant_output": None,
        "current_step": "start",
        "error": None,
    }
    return state
def run_pipeline(
    image,
    doctor_diagnosis: str,
    clinical_context: str,
    modality: str | None = None,
) -> PipelineState:
    """Run the full debiasing pipeline (blocking).

    Builds the initial state from the clinical inputs and invokes the
    compiled graph, returning the final pipeline state.
    """
    state = _make_initial_state(
        image, doctor_diagnosis, clinical_context, modality=modality
    )
    return get_graph().invoke(state)
def stream_pipeline(
    image,
    doctor_diagnosis: str,
    clinical_context: str,
    modality: str | None = None,
):
    """
    Stream the pipeline, yielding (node_name, state) after each agent completes.
    Use this for progressive UI updates.
    """
    state = _make_initial_state(
        image, doctor_diagnosis, clinical_context, modality=modality
    )
    # Each streamed event maps a node name to that node's state update;
    # flatten every event into (node_name, state_update) pairs.
    for event in get_graph().stream(state, stream_mode="updates"):
        yield from event.items()