|
|
"""Main LangGraph Agent System Implementation""" |
|
|
import os
from typing import Dict, Any, Optional, TypedDict, Literal

from langchain_core.messages import BaseMessage, HumanMessage
from langgraph.graph import StateGraph, END

from src.agents.plan_node import plan_node
from src.agents.router_node import router_node, should_route_to_agent
from src.agents.retrieval_agent import retrieval_agent
from src.agents.execution_agent import execution_agent
from src.agents.critic_agent import critic_agent
from src.agents.verification_node import verification_node, should_retry
from src.memory import memory_manager
from src.tracing import (
    get_langfuse_callback_handler,
    update_trace_metadata,
    trace_agent_execution,
    flush_langfuse,
)
|
|
|
|
|
|
|
|
class AgentState(TypedDict):
    """State schema for the agent system.

    Every key is seeded by ``run_agent_system`` before the graph is invoked,
    so nodes can rely on all of them being present.
    """

    # Conversation messages; seeded with the user's query as a HumanMessage.
    messages: list[BaseMessage]

    # Planning / routing bookkeeping (all seeded empty / False).
    plan_complete: bool
    next_agent: str
    routing_decision: str
    routing_reason: str
    # Starts as "planning"; the value "complete" ends the graph after
    # verification and is what the fallback node writes.
    current_step: str

    # Agent outputs. ``agent_response`` is seeded as None, hence Optional.
    agent_response: Optional[BaseMessage]
    execution_result: str

    # Quality review results (seeded as "", True and 7 respectively).
    critic_assessment: str
    quality_pass: bool
    quality_score: int
    # Values this module reads or writes: "" (initial), "failed",
    # "failed_max_attempts" and "fallback".
    verification_status: str

    # Starts at 1; a failed verification is retried while this is below 3.
    attempt_count: int
    # The text ultimately returned to the caller of ``run_agent_system``.
    final_answer: str
|
|
|
|
|
|
|
|
def create_agent_graph() -> StateGraph:
    """Create the (uncompiled) LangGraph agent workflow.

    Topology::

        plan -> router
        router -> retrieval | execution | critic   (via should_route_to_agent)
        retrieval -> critic
        execution -> critic
        critic -> verification
        verification -> router (retry) | fallback | END
        fallback -> END

    Returns:
        The configured ``StateGraph``. The caller is responsible for
        compiling it (optionally with a checkpointer).
    """
    # A failed verification is retried only while attempt_count is below this.
    max_attempts = 3

    def fallback_node(state: Dict[str, Any]) -> Dict[str, Any]:
        """Return the state with a canned apology as the final answer."""
        print("Fallback Node: Providing basic response")

        # Echo the most recent human message back to the user, if there is one.
        user_query = next(
            (
                msg.content
                for msg in reversed(state.get("messages") or [])
                if msg.type == "human"
            ),
            None,
        )

        fallback_answer = "I apologize, but I was unable to provide a satisfactory answer to your question."
        if user_query:
            fallback_answer += f" Your question was: {user_query}"

        return {
            **state,
            "final_answer": fallback_answer,
            "verification_status": "fallback",
            "current_step": "complete",
        }

    def verification_next(state: Dict[str, Any]) -> str:
        """Determine the next step after verification.

        Returns "router" to retry, "fallback" to give up with an apology,
        or ``END`` to finish the run.
        """
        # A completed run always terminates, whatever the status says.
        if state.get("current_step", "") == "complete":
            return END

        verification_status = state.get("verification_status", "")
        if verification_status == "failed":
            if state.get("attempt_count", 1) < max_attempts:
                return "router"
            # Retry budget exhausted: hand over to the fallback node instead
            # of ending the run without any answer for the user.
            return "fallback"
        if verification_status == "failed_max_attempts":
            return "fallback"
        return END

    workflow = StateGraph(AgentState)

    # Nodes.
    workflow.add_node("plan", plan_node)
    workflow.add_node("router", router_node)
    workflow.add_node("retrieval", retrieval_agent)
    workflow.add_node("execution", execution_agent)
    workflow.add_node("critic", critic_agent)
    workflow.add_node("verification", verification_node)
    workflow.add_node("fallback", fallback_node)

    # Entry point and the fixed plan -> router hop.
    workflow.set_entry_point("plan")
    workflow.add_edge("plan", "router")

    # The router picks which agent (or the critic directly) runs next.
    workflow.add_conditional_edges(
        "router",
        should_route_to_agent,
        {
            "retrieval": "retrieval",
            "execution": "execution",
            "critic": "critic",
        },
    )

    # Both agents feed the critic, which feeds verification.
    workflow.add_edge("retrieval", "critic")
    workflow.add_edge("execution", "critic")
    workflow.add_edge("critic", "verification")

    # Verification either retries, falls back, or finishes.
    workflow.add_conditional_edges(
        "verification",
        verification_next,
        {
            "router": "router",
            "fallback": "fallback",
            END: END,
        },
    )

    workflow.add_edge("fallback", END)

    return workflow
|
|
|
|
|
|
|
|
def run_agent_system(
    query: str,
    user_id: Optional[str] = None,
    session_id: Optional[str] = None,
) -> str:
    """Run the complete agent system with a user query.

    Args:
        query: The user question.
        user_id: Optional user identifier for tracing.
        session_id: Optional session identifier for tracing. It is also used
            as the checkpointer ``thread_id`` ("default" when omitted).

    Returns:
        The final answer; ``"No answer generated"`` when the workflow ends
        without producing one; or an apology containing the error text when
        the workflow raises.
    """
    print(f"Agent System: Processing query: {query[:100]}...")

    with trace_agent_execution(name="user-request", user_id=user_id, session_id=session_id):
        try:
            update_trace_metadata(
                user_id=user_id,
                session_id=session_id,
                tags=["agent_system"],
            )

            workflow = create_agent_graph()

            # Compile with persistence only when a checkpointer is available.
            checkpointer = memory_manager.get_checkpointer()
            if checkpointer:
                app = workflow.compile(checkpointer=checkpointer)
            else:
                app = workflow.compile()

            initial_state = {
                "messages": [HumanMessage(content=query)],
                "plan_complete": False,
                "next_agent": "",
                "routing_decision": "",
                "routing_reason": "",
                "current_step": "planning",
                "agent_response": None,
                "execution_result": "",
                "critic_assessment": "",
                "quality_pass": True,
                "quality_score": 7,
                "verification_status": "",
                "attempt_count": 1,
                "final_answer": "",
            }

            callback_handler = get_langfuse_callback_handler()
            config: Dict[str, Any] = {
                "configurable": {"thread_id": session_id or "default"},
            }
            if callback_handler:
                config["callbacks"] = [callback_handler]

            print("Agent System: Executing workflow...")
            final_state = app.invoke(initial_state, config=config)

            # "final_answer" is always present (seeded with ""), so a plain
            # .get() default would never apply; treat empty as missing.
            generated_answer = final_state.get("final_answer")
            final_answer = generated_answer or "No answer generated"

            # Memory ingestion is a side effect: skip it when there is no real
            # answer, and never let its failure discard a good answer.
            if generated_answer:
                try:
                    if memory_manager.should_ingest(query):
                        memory_manager.ingest_qa_pair(query, final_answer)
                except Exception as ingest_error:
                    print(f"Agent System: Memory ingestion failed: {ingest_error}")

            print(f"Agent System: Completed. Final answer: {final_answer[:100]}...")
            return final_answer

        except Exception as e:
            # Top-level boundary: report the failure to the caller as text.
            print(f"Agent System Error: {e}")
            return (
                f"I apologize, but I encountered an error while processing your question: {e}"
            )

        finally:
            # Flushing traces is best-effort and must not mask the result,
            # but a failure is reported rather than silently dropped.
            try:
                flush_langfuse()
            except Exception as flush_error:
                print(f"Agent System: Failed to flush traces: {flush_error}")
|
|
|
|
|
|
|
|
|
|
|
# Names exported by ``from <module> import *``.
__all__ = [
    "run_agent_system",
    "create_agent_graph",
    "AgentState",
]