|
|
""" |
|
|
LangGraph workflow graph builder for multi-agent RAG system. |
|
|
""" |
|
|
import asyncio
import logging
from typing import Any, AsyncIterator, Dict, Iterator, Optional

import nest_asyncio
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver

from utils.langgraph_state import AgentState
from orchestration.nodes import (
    retriever_node,
    analyzer_node,
    filter_node,
    synthesis_node,
    citation_node,
    finalize_node,
    should_continue_after_retriever,
    should_continue_after_filter,
)
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
nest_asyncio.apply() |
|
|
|
|
|
|
|
|
def create_workflow_graph(
    retriever_agent,
    analyzer_agent,
    synthesis_agent,
    citation_agent,
    use_checkpointing: bool = True,
) -> Any:
    """
    Build and compile the LangGraph workflow for the multi-agent RAG system.

    The pipeline is: retriever -> analyzer -> filter -> synthesis ->
    citation -> finalize, with conditional early exits to END after the
    retriever and filter stages.

    Args:
        retriever_agent: RetrieverAgent instance
        analyzer_agent: AnalyzerAgent instance
        synthesis_agent: SynthesisAgent instance
        citation_agent: CitationAgent instance
        use_checkpointing: Whether to enable workflow checkpointing

    Returns:
        Compiled LangGraph application
    """
    logger.info("Creating LangGraph workflow graph")

    workflow = StateGraph(AgentState)

    # Agent-backed nodes are wrapped in lambdas so each receives its agent
    # instance alongside the state; filter/finalize operate on state alone.
    node_table = (
        ("retriever", lambda state: retriever_node(state, retriever_agent)),
        ("analyzer", lambda state: analyzer_node(state, analyzer_agent)),
        ("filter", filter_node),
        ("synthesis", lambda state: synthesis_node(state, synthesis_agent)),
        ("citation", lambda state: citation_node(state, citation_agent)),
        ("finalize", finalize_node),
    )
    for node_name, node_fn in node_table:
        workflow.add_node(node_name, node_fn)

    workflow.set_entry_point("retriever")

    # After retrieval, should_continue_after_retriever decides whether to
    # proceed to analysis or route straight to END.
    workflow.add_conditional_edges(
        "retriever",
        should_continue_after_retriever,
        {"continue": "analyzer", "end": END},
    )

    workflow.add_edge("analyzer", "filter")

    # After filtering, should_continue_after_filter likewise decides whether
    # the run continues into synthesis or ends early.
    workflow.add_conditional_edges(
        "filter",
        should_continue_after_filter,
        {"continue": "synthesis", "end": END},
    )

    # The tail of the pipeline is strictly linear.
    workflow.add_edge("synthesis", "citation")
    workflow.add_edge("citation", "finalize")
    workflow.add_edge("finalize", END)

    if use_checkpointing:
        app = workflow.compile(checkpointer=MemorySaver())
        logger.info("Workflow compiled with checkpointing enabled")
        return app

    app = workflow.compile()
    logger.info("Workflow compiled without checkpointing")
    return app
|
|
|
|
|
|
|
|
async def run_workflow_async(
    app: Any,
    initial_state: AgentState,
    thread_id: Optional[str] = None,
) -> AsyncIterator[AgentState]:
    """
    Run LangGraph workflow asynchronously with streaming.

    Args:
        app: Compiled LangGraph application
        initial_state: Initial workflow state
        thread_id: Optional thread ID for checkpointing

    Yields:
        State updates after each node execution.  On failure, the error is
        appended to the state's "errors" list and the annotated initial
        state is yielded as the final item (the workflow does not raise).
    """
    config = {"configurable": {"thread_id": thread_id or "default"}}

    logger.info("Starting async workflow execution (thread_id: %s)", thread_id)

    try:
        # astream() emits one event per node completion, keyed by node name.
        async for event in app.astream(initial_state, config=config):
            for node_name, node_state in event.items():
                logger.debug("Node '%s' completed", node_name)
                yield node_state

    except Exception as e:
        # logger.exception keeps the traceback in the log record, unlike
        # the plain logger.error the original used.
        logger.exception("Error during workflow execution: %s", e)
        # setdefault guards against callers that built a state without the
        # "errors" key; the original would raise KeyError here and mask
        # the real failure.
        initial_state.setdefault("errors", []).append(f"Workflow error: {str(e)}")
        yield initial_state
|
|
|
|
|
|
|
|
def _run_workflow_streaming(
    app: Any,
    initial_state: AgentState,
    thread_id: Optional[str] = None,
) -> Iterator[AgentState]:
    """
    Run LangGraph workflow with streaming (internal generator function).

    Bridges the async generator run_workflow_async() into a plain sync
    generator by driving a dedicated event loop one step at a time.

    Args:
        app: Compiled LangGraph application
        initial_state: Initial workflow state
        thread_id: Optional thread ID for checkpointing

    Yields:
        State updates after each node execution
    """
    # Remember any pre-existing loop so it can be restored afterwards; the
    # original code clobbered the thread's event loop permanently via
    # set_event_loop() and never put it back.
    try:
        previous_loop = asyncio.get_event_loop()
    except RuntimeError:
        previous_loop = None

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    # Call run_workflow_async directly; the original wrapped it in an extra
    # pass-through async generator that added nothing.
    async_gen = run_workflow_async(app, initial_state, thread_id)

    try:
        while True:
            try:
                yield loop.run_until_complete(async_gen.__anext__())
            except StopAsyncIteration:
                break
    finally:
        # Close the async generator on this loop so its cleanup runs now,
        # not later in a GC finalizer after the loop has been closed.
        loop.run_until_complete(async_gen.aclose())
        loop.close()
        # Restore whatever loop (possibly None) was installed before.
        asyncio.set_event_loop(previous_loop)
|
|
|
|
|
|
|
|
def run_workflow(
    app: Any,
    initial_state: AgentState,
    thread_id: Optional[str] = None,
    use_streaming: bool = False,
) -> Any:
    """
    Run LangGraph workflow (sync wrapper for Gradio compatibility).

    Args:
        app: Compiled LangGraph application
        initial_state: Initial workflow state
        thread_id: Optional thread ID for checkpointing
        use_streaming: Whether to stream intermediate results

    Returns:
        Final state (if use_streaming=False) or generator of states (if
        use_streaming=True).  On failure the error is recorded in the
        state's "errors" list and the initial state is returned, so the
        caller always receives a state object to render.
    """
    config = {"configurable": {"thread_id": thread_id or "default"}}

    logger.info(
        "Starting workflow execution (thread_id: %s, streaming: %s)",
        thread_id,
        use_streaming,
    )

    try:
        if use_streaming:
            # NOTE: generators are lazy -- errors raised while iterating
            # surface inside run_workflow_async (which records them on the
            # state), not in this try block.
            return _run_workflow_streaming(app, initial_state, thread_id)

        final_state = app.invoke(initial_state, config=config)
        logger.info("Workflow execution completed")
        return final_state

    except Exception as e:
        # logger.exception preserves the traceback, unlike logger.error.
        logger.exception("Error during workflow execution: %s", e)
        # setdefault guards against states built without an "errors" key;
        # the original indexed it directly and could raise KeyError here.
        initial_state.setdefault("errors", []).append(
            f"Workflow execution error: {str(e)}"
        )
        return initial_state
|
|
|
|
|
|
|
|
def get_workflow_state(
    app: Any,
    thread_id: str,
) -> Optional[AgentState]:
    """
    Get current state of a workflow execution.

    Args:
        app: Compiled LangGraph application
        thread_id: Thread ID of the workflow

    Returns:
        Current state or None if not found
    """
    config = {"configurable": {"thread_id": thread_id}}

    try:
        snapshot = app.get_state(config)
    except Exception as e:
        logger.error(f"Error getting workflow state: {e}")
        return None

    # get_state may return a falsy snapshot when no checkpoint exists.
    return snapshot.values if snapshot else None
|
|
|