|
|
""" |
|
|
LangGraph Multi-Agent System Implementation |
|
|
|
|
|
This module implements a multi-agent system using LangGraph with the following components: |
|
|
- LeadAgent: Orchestrates the workflow and makes decisions |
|
|
- ResearchAgent: Handles information gathering and research tasks |
|
|
- CodeAgent: Handles computational and code execution tasks |
|
|
- AnswerFormatter: Formats final answers according to GAIA requirements |
|
|
- Memory: Persistent storage for context and learning |
|
|
""" |
|
|
|
|
|
import os |
|
|
from typing import Dict, Any, TypedDict, Literal, Annotated, List |
|
|
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, AIMessage |
|
|
from langgraph.graph import StateGraph, START, END |
|
|
from langgraph.types import Command |
|
|
import operator |
|
|
from dotenv import load_dotenv |
|
|
|
|
|
|
|
|
from observability import ( |
|
|
start_root_span, |
|
|
get_callback_handler, |
|
|
flush_traces, |
|
|
shutdown_observability |
|
|
) |
|
|
|
|
|
|
|
|
# Load environment variables from the "env.local" file.
# NOTE(review): python-dotenv opens a relative path against the process's
# current working directory, not this file's directory, so running from
# elsewhere silently loads nothing — confirm that is intended.
# NOTE(review): `observability` is imported above this call; if it reads
# environment variables at import time, it will not see values from env.local.
load_dotenv(dotenv_path="env.local")
|
|
|
|
|
class AgentState(TypedDict):
    """
    Shared state read and written by the lead, research, code and formatter nodes.

    Fields annotated with ``operator.add`` use it as their LangGraph reducer.
    A node's update to such a field is concatenated onto the existing value.
    Every other field is simply overwritten by the most recent update.
    """

    # --- conversation ---------------------------------------------------------
    # Chat history; list updates are appended via the operator.add reducer.
    messages: Annotated[List[BaseMessage], operator.add]

    # --- working results ------------------------------------------------------
    # Latest candidate answer; overwritten on each update.
    draft_answer: str
    # Accumulated research text; string updates are concatenated.
    research_notes: Annotated[str, operator.add]
    # Accumulated code-execution output; string updates are concatenated.
    code_outputs: Annotated[str, operator.add]

    # --- loop control ---------------------------------------------------------
    # Compared against max_iterations by the lead router to cap the loop.
    # NOTE(review): presumably incremented by the lead agent — TODO confirm.
    loop_counter: int
    max_iterations: int

    # --- routing --------------------------------------------------------------
    # Successor the lead router follows; presumably set by the lead agent.
    next: Literal["research", "code", "formatter", "__end__"]

    # --- output ---------------------------------------------------------------
    # Answer returned by run_agent_system; a non-empty value makes the lead
    # router end the run.
    final_answer: str

    # --- caller identifiers ---------------------------------------------------
    # Copied from run_agent_system's arguments (also attached to its root span).
    user_id: str
    session_id: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_agent_graph():
    """
    Build (but do not compile) the multi-agent LangGraph workflow.

    Topology::

        START     -> lead
        lead      -> research | code | formatter | END   (via route_from_lead)
        research  -> lead
        code      -> lead
        formatter -> END

    The lead node is the hub. Research and code always hand control back to
    it. A run terminates either through the formatter or when the state
    already carries a final answer.

    Returns:
        The uncompiled ``StateGraph``. Callers must call ``.compile()`` on it.
    """
    # Agent implementations are imported inside the function rather than at
    # module top. NOTE(review): possibly to avoid a circular import with
    # modules that import AgentState from here — TODO confirm.
    from agents.lead_agent import lead_agent
    from agents.research_agent import research_agent
    from agents.code_agent import code_agent
    from agents.answer_formatter import answer_formatter

    workflow = StateGraph(AgentState)

    workflow.add_node("lead", lead_agent)
    workflow.add_node("research", research_agent)
    workflow.add_node("code", code_agent)
    workflow.add_node("formatter", answer_formatter)

    workflow.add_edge(START, "lead")

    def route_from_lead(state: AgentState) -> str:
        """Pick the lead node's successor.

        - A final answer already exists: stop.
        - The research/code loop budget is spent: go to the formatter, so the
          run still produces a formatted answer. Jumping straight to END here
          would leave ``final_answer`` empty.
        - Otherwise: follow the ``next`` decision (default ``"research"``).
        """
        if state.get("final_answer"):
            return "__end__"
        if state.get("loop_counter", 0) >= state.get("max_iterations", 3):
            return "formatter"
        return state.get("next", "research")

    workflow.add_conditional_edges(
        "lead",
        route_from_lead,
        {
            "research": "research",
            "code": "code",
            "formatter": "formatter",
            "__end__": END,
        },
    )

    # Workers report back to the lead; the formatter is terminal.
    workflow.add_edge("research", "lead")
    workflow.add_edge("code", "lead")
    workflow.add_edge("formatter", END)

    return workflow
|
|
|
|
|
|
|
|
async def run_agent_system(
    query: str,
    user_id: str = "default_user",
    session_id: str = "default_session",
    max_iterations: int = 3
) -> str:
    """
    Answer ``query`` by running the compiled multi-agent graph end to end.

    The whole run is wrapped in a root tracing span named "user-request". The
    tracing callback handler is attached to the graph invocation when one is
    available. Exceptions are not propagated: they are printed and turned
    into an apology string. Traces are flushed on every exit path.

    Args:
        query: The user's question.
        user_id: Identifier recorded on the root span and in the graph state.
        session_id: Identifier recorded on the root span and in the graph state.
        max_iterations: Maximum number of research/code loops. It is stored in
            the state, where the lead router compares it with ``loop_counter``.

    Returns:
        The ``final_answer`` field of the finished graph state. If anything
        raised, an apology message embedding the error text instead.
    """
    try:
        handler = get_callback_handler()
        span_metadata = {"query": query, "max_iterations": max_iterations}

        with start_root_span(
            name="user-request",
            user_id=user_id,
            session_id=session_id,
            metadata=span_metadata,
        ) as root_span:
            app = create_agent_graph().compile()

            initial_state: AgentState = dict(
                messages=[HumanMessage(content=query)],
                draft_answer="",
                research_notes="",
                code_outputs="",
                loop_counter=0,
                max_iterations=max_iterations,
                next="research",
                final_answer="",
                user_id=user_id,
                session_id=session_id,
            )

            # Attach the tracing callback only when one was obtained.
            if handler:
                invoke_kwargs = {"config": {"callbacks": [handler]}}
            else:
                print("Warning: Running without Langfuse tracing")
                invoke_kwargs = {}
            final_state = await app.ainvoke(initial_state, **invoke_kwargs)

            answer = final_state["final_answer"]
            if root_span:
                root_span.update_trace(output={"final_answer": answer})
            return answer

    except Exception as e:
        # Top-level boundary: report the failure to the caller as text.
        print(f"Error in agent system: {e}")
        return f"I apologize, but I encountered an error while processing your query: {str(e)}"

    finally:
        flush_traces(background=True)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        """Run one sample query through the agent system and print the answer."""
        answer = await run_agent_system(
            "What is the capital of Maharashtra?",
            user_id="test_user",
            session_id="test_session",
        )
        print(f"Final Answer: {answer}")

    asyncio.run(_demo())