File size: 7,023 Bytes
aca8ab4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 |
"""
LangGraph workflow graph builder for multi-agent RAG system.
"""
import asyncio
import logging
from typing import Any, AsyncIterator, Dict, Iterator, Optional

import nest_asyncio
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, END

from orchestration.nodes import (
    retriever_node,
    analyzer_node,
    filter_node,
    synthesis_node,
    citation_node,
    finalize_node,
    should_continue_after_retriever,
    should_continue_after_filter,
)
from utils.langgraph_state import AgentState
# Module-level logger, one per module per stdlib logging convention.
logger = logging.getLogger(__name__)
# Enable nested event loops for Gradio compatibility
# (Gradio runs its own loop; nest_asyncio lets run_until_complete be
# called from within it — see _run_workflow_streaming below).
nest_asyncio.apply()
def create_workflow_graph(
    retriever_agent,
    analyzer_agent,
    synthesis_agent,
    citation_agent,
    use_checkpointing: bool = True,
) -> Any:
    """
    Create LangGraph workflow for multi-agent RAG system.

    Builds the retriever -> analyzer -> filter -> synthesis -> citation
    -> finalize pipeline, with early exits after the retriever and filter
    stages when their conditional routers return "end".

    Args:
        retriever_agent: RetrieverAgent instance
        analyzer_agent: AnalyzerAgent instance
        synthesis_agent: SynthesisAgent instance
        citation_agent: CitationAgent instance
        use_checkpointing: Whether to enable workflow checkpointing

    Returns:
        Compiled LangGraph application
    """
    logger.info("Creating LangGraph workflow graph")

    graph = StateGraph(AgentState)

    # Nodes that need an agent instance bound to the node function.
    # Defaults in the lambda pin each (fn, agent) pair at definition time,
    # avoiding the late-binding closure pitfall.
    agent_bound_nodes = {
        "retriever": (retriever_node, retriever_agent),
        "analyzer": (analyzer_node, analyzer_agent),
        "synthesis": (synthesis_node, synthesis_agent),
        "citation": (citation_node, citation_agent),
    }
    for node_name, (node_fn, agent) in agent_bound_nodes.items():
        graph.add_node(node_name, lambda state, fn=node_fn, ag=agent: fn(state, ag))

    # Nodes that operate on the state alone.
    graph.add_node("filter", filter_node)
    graph.add_node("finalize", finalize_node)

    graph.set_entry_point("retriever")

    # Retrieval may find nothing worth analyzing; router decides.
    graph.add_conditional_edges(
        "retriever",
        should_continue_after_retriever,
        {"continue": "analyzer", "end": END},
    )
    graph.add_edge("analyzer", "filter")
    # Filtering may discard everything; router decides.
    graph.add_conditional_edges(
        "filter",
        should_continue_after_filter,
        {"continue": "synthesis", "end": END},
    )
    graph.add_edge("synthesis", "citation")
    graph.add_edge("citation", "finalize")
    graph.add_edge("finalize", END)

    if use_checkpointing:
        compiled = graph.compile(checkpointer=MemorySaver())
        logger.info("Workflow compiled with checkpointing enabled")
    else:
        compiled = graph.compile()
        logger.info("Workflow compiled without checkpointing")
    return compiled
async def run_workflow_async(
    app: Any,
    initial_state: AgentState,
    thread_id: Optional[str] = None,
) -> AsyncIterator[AgentState]:
    """
    Run LangGraph workflow asynchronously with streaming.

    Args:
        app: Compiled LangGraph application
        initial_state: Initial workflow state
        thread_id: Optional thread ID for checkpointing (defaults to "default")

    Yields:
        State updates after each node execution. On failure, yields the
        initial state with the error message appended to its "errors" list
        instead of raising, so streaming consumers always receive a state.
    """
    config = {"configurable": {"thread_id": thread_id or "default"}}
    logger.info(f"Starting async workflow execution (thread_id: {thread_id})")
    try:
        async for event in app.astream(initial_state, config=config):
            # Each event maps the just-completed node's name to its output state.
            for node_name, node_state in event.items():
                logger.debug(f"Node '{node_name}' completed")
                yield node_state
    except Exception as e:
        # exception() logs the traceback in addition to the message.
        logger.exception(f"Error during workflow execution: {e}")
        # setdefault guards against states created without an "errors" key.
        initial_state.setdefault("errors", []).append(f"Workflow error: {str(e)}")
        yield initial_state
def _run_workflow_streaming(
    app: Any,
    initial_state: AgentState,
    thread_id: Optional[str] = None,
) -> Iterator[AgentState]:
    """
    Run LangGraph workflow with streaming (internal generator function).

    Bridges the async generator run_workflow_async into a synchronous
    generator by pumping a dedicated event loop one item at a time.

    Args:
        app: Compiled LangGraph application
        initial_state: Initial workflow state
        thread_id: Optional thread ID for checkpointing

    Yields:
        State updates after each node execution
    """
    # Dedicated loop so run_until_complete can be called per item
    # (nest_asyncio, applied at module import, tolerates an outer loop).
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    # Drive run_workflow_async directly; the previous stream_wrapper
    # indirection added nothing.
    async_gen = run_workflow_async(app, initial_state, thread_id)
    try:
        while True:
            try:
                yield loop.run_until_complete(async_gen.__anext__())
            except StopAsyncIteration:
                break
    finally:
        # Close the async generator before tearing down the loop so its
        # cleanup runs — including when the consumer abandons this
        # generator mid-stream. aclose() is a no-op if already exhausted.
        loop.run_until_complete(async_gen.aclose())
        loop.close()
def run_workflow(
    app: Any,
    initial_state: AgentState,
    thread_id: Optional[str] = None,
    use_streaming: bool = False,
) -> Any:
    """
    Run LangGraph workflow (sync wrapper for Gradio compatibility).

    Args:
        app: Compiled LangGraph application
        initial_state: Initial workflow state
        thread_id: Optional thread ID for checkpointing (defaults to "default")
        use_streaming: Whether to stream intermediate results

    Returns:
        Final state (if use_streaming=False) or generator of states (if
        use_streaming=True). On invoke failure the initial state is
        returned with the error appended to its "errors" list.
    """
    config = {"configurable": {"thread_id": thread_id or "default"}}
    logger.info(f"Starting workflow execution (thread_id: {thread_id}, streaming: {use_streaming})")
    if use_streaming:
        # Calling a generator function is lazy and never raises, so no
        # try/except here — streaming errors surface inside the generator
        # (handled in run_workflow_async).
        return _run_workflow_streaming(app, initial_state, thread_id)
    try:
        final_state = app.invoke(initial_state, config=config)
    except Exception as e:
        # exception() logs the traceback in addition to the message.
        logger.exception(f"Error during workflow execution: {e}")
        # setdefault guards against states created without an "errors" key.
        initial_state.setdefault("errors", []).append(f"Workflow execution error: {str(e)}")
        return initial_state
    logger.info("Workflow execution completed")
    return final_state
def get_workflow_state(
    app: Any,
    thread_id: str,
) -> Optional[AgentState]:
    """
    Get current state of a workflow execution.

    Args:
        app: Compiled LangGraph application
        thread_id: Thread ID of the workflow

    Returns:
        Current state or None if not found
    """
    config = {"configurable": {"thread_id": thread_id}}
    try:
        snapshot = app.get_state(config)
    except Exception as e:
        logger.error(f"Error getting workflow state: {e}")
        return None
    # get_state may return a falsy snapshot when the thread is unknown.
    if not snapshot:
        return None
    return snapshot.values
|