""" LangGraph workflow graph builder for multi-agent RAG system. """ import logging from typing import Optional, Iterator, Dict, Any import asyncio import nest_asyncio from langgraph.graph import StateGraph, END from langgraph.checkpoint.memory import MemorySaver from utils.langgraph_state import AgentState from orchestration.nodes import ( retriever_node, analyzer_node, filter_node, synthesis_node, citation_node, finalize_node, should_continue_after_retriever, should_continue_after_filter, ) logger = logging.getLogger(__name__) # Enable nested event loops for Gradio compatibility nest_asyncio.apply() def create_workflow_graph( retriever_agent, analyzer_agent, synthesis_agent, citation_agent, use_checkpointing: bool = True, ) -> Any: """ Create LangGraph workflow for multi-agent RAG system. Args: retriever_agent: RetrieverAgent instance analyzer_agent: AnalyzerAgent instance synthesis_agent: SynthesisAgent instance citation_agent: CitationAgent instance use_checkpointing: Whether to enable workflow checkpointing Returns: Compiled LangGraph application """ logger.info("Creating LangGraph workflow graph") # Create state graph workflow = StateGraph(AgentState) # Add nodes with agent instances bound workflow.add_node( "retriever", lambda state: retriever_node(state, retriever_agent) ) workflow.add_node( "analyzer", lambda state: analyzer_node(state, analyzer_agent) ) workflow.add_node( "filter", filter_node ) workflow.add_node( "synthesis", lambda state: synthesis_node(state, synthesis_agent) ) workflow.add_node( "citation", lambda state: citation_node(state, citation_agent) ) workflow.add_node( "finalize", finalize_node ) # Set entry point workflow.set_entry_point("retriever") # Add conditional edge after retriever workflow.add_conditional_edges( "retriever", should_continue_after_retriever, { "continue": "analyzer", "end": END, } ) # Add edge from analyzer to filter workflow.add_edge("analyzer", "filter") # Add conditional edge after filter workflow.add_conditional_edges( "filter", should_continue_after_filter, { "continue": "synthesis", "end": END, } ) # Add edges for synthesis, citation, and finalize workflow.add_edge("synthesis", "citation") workflow.add_edge("citation", "finalize") workflow.add_edge("finalize", END) # Compile workflow if use_checkpointing: checkpointer = MemorySaver() app = workflow.compile(checkpointer=checkpointer) logger.info("Workflow compiled with checkpointing enabled") else: app = workflow.compile() logger.info("Workflow compiled without checkpointing") return app async def run_workflow_async( app: Any, initial_state: AgentState, thread_id: Optional[str] = None, ) -> Iterator[AgentState]: """ Run LangGraph workflow asynchronously with streaming. Args: app: Compiled LangGraph application initial_state: Initial workflow state thread_id: Optional thread ID for checkpointing Yields: State updates after each node execution """ config = {"configurable": {"thread_id": thread_id or "default"}} logger.info(f"Starting async workflow execution (thread_id: {thread_id})") try: async for event in app.astream(initial_state, config=config): # Event is a dict with node name as key for node_name, node_state in event.items(): logger.debug(f"Node '{node_name}' completed") yield node_state except Exception as e: logger.error(f"Error during workflow execution: {e}") # Yield error state initial_state["errors"].append(f"Workflow error: {str(e)}") yield initial_state def _run_workflow_streaming( app: Any, initial_state: AgentState, thread_id: Optional[str] = None, ) -> Iterator[AgentState]: """ Run LangGraph workflow with streaming (internal generator function). Args: app: Compiled LangGraph application initial_state: Initial workflow state thread_id: Optional thread ID for checkpointing Yields: State updates after each node execution """ # Create new event loop for streaming loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: async def stream_wrapper(): async for state in run_workflow_async(app, initial_state, thread_id): yield state async_gen = stream_wrapper() # Convert async generator to sync generator while True: try: yield loop.run_until_complete(async_gen.__anext__()) except StopAsyncIteration: break finally: loop.close() def run_workflow( app: Any, initial_state: AgentState, thread_id: Optional[str] = None, use_streaming: bool = False, ) -> Any: """ Run LangGraph workflow (sync wrapper for Gradio compatibility). Args: app: Compiled LangGraph application initial_state: Initial workflow state thread_id: Optional thread ID for checkpointing use_streaming: Whether to stream intermediate results Returns: Final state (if use_streaming=False) or generator of states (if use_streaming=True) """ config = {"configurable": {"thread_id": thread_id or "default"}} logger.info(f"Starting workflow execution (thread_id: {thread_id}, streaming: {use_streaming})") try: if use_streaming: # Return generator for streaming return _run_workflow_streaming(app, initial_state, thread_id) else: # Non-streaming execution - just return final state final_state = app.invoke(initial_state, config=config) logger.info("Workflow execution completed") return final_state except Exception as e: logger.error(f"Error during workflow execution: {e}") initial_state["errors"].append(f"Workflow execution error: {str(e)}") return initial_state def get_workflow_state( app: Any, thread_id: str, ) -> Optional[AgentState]: """ Get current state of a workflow execution. Args: app: Compiled LangGraph application thread_id: Thread ID of the workflow Returns: Current state or None if not found """ try: config = {"configurable": {"thread_id": thread_id}} state = app.get_state(config) return state.values if state else None except Exception as e: logger.error(f"Error getting workflow state: {e}") return None