# rag_agent/agent/nodes.py
# Last change: "changed rag method flags" — Cheh Kit Hong, commit 94e0eef
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage, RemoveMessage
from typing import Literal
from langgraph.graph import START, END
from .state import AgentState, QueryAnalysis
from .prompts import *
from .tools import intialize_chroma_vectorstore
def router_node(state: AgentState, llm):
    """
    Classify the latest user query and record the chosen RAG method.

    Sends the routing prompt plus the query to the LLM, which is expected to
    answer with a single method label (e.g. VECTORDB / WEBSEARCH / GENERAL —
    see routing_logic for the labels actually consumed downstream).

    Args:
        state: Graph state; the last message's content is treated as the query.
        llm: Chat model exposing ``invoke``.

    Returns:
        Partial state update containing only ``rag_method``. Returning just
        the changed key (instead of the whole mutated state, as before) keeps
        the existing message history from being re-fed through the graph's
        reducers.
    """
    query = state["messages"][-1].content
    rag_method_prompt = determine_rag_method_prompt()
    rag_method_result = llm.invoke(
        [SystemMessage(content=rag_method_prompt), HumanMessage(content=query)]
    )
    # Normalize so routing_logic's exact string comparison is reliable even if
    # the model adds whitespace or answers in lowercase.
    rag_method = rag_method_result.content.strip().upper()
    return {"rag_method": rag_method}
def routing_logic(state: AgentState) -> str:
    """
    Map the classified ``rag_method`` onto the name of the next graph node.

    Unknown labels (the LLM violated the routing prompt) are logged and the
    flow is terminated via END.
    """
    destinations = {
        "VECTORDB": "vectordb_node",
        "WEBSEARCH": "web_search_agent_node",
        # Questions needing neither retrieval nor web search go straight to
        # answer generation.
        "GENERAL": "generate_node",
    }
    rag_method = state["rag_method"]
    try:
        return destinations[rag_method]
    except KeyError:
        # If the LLM violates the prompt and outputs an unknown word,
        print(f"ERROR: Router returned unclassified intent: {rag_method}. Terminating flow.")
        return END
def vectordb_node(state: "AgentState", vectorstore):
    """
    Retrieve supporting context for the latest user query from the vector DB.

    Runs a top-k similarity search and stores the concatenated document texts
    under ``context`` for generate_node to fold into the system prompt.

    Args:
        state: Graph state; the last message's content is used as the query.
        vectorstore: Object exposing ``similarity_search(query, k)`` returning
            documents that carry ``page_content``.

    Returns:
        Partial state update containing only ``context``. Returning just the
        changed key (instead of the whole mutated state, as before) keeps the
        message history from being re-fed through the graph's reducers.
    """
    context_docs = vectorstore.similarity_search(
        query=state["messages"][-1].content,
        k=5,  # small fixed window keeps the generated prompt within token limits
    )
    context = "\n\n".join(doc.page_content for doc in context_docs)
    return {"context": context}
def web_search_agent_node(state: AgentState, llm):
    """
    LLM agent step that decides which web search tools to call.

    Prepends a tool-use instruction to the conversation and invokes the
    (tool-bound) LLM, which produces an AIMessage carrying ``tool_calls``.
    """
    instruction = SystemMessage(content="""You are a web search assistant.
Use the available search tools (web_search_tavily, wikipedia_search) to find information about the user's query.
Call the appropriate tool with the query.""")
    # Instruction first, then the full conversation so far.
    prompt = [instruction, *state["messages"]]
    response = llm.invoke(prompt)
    return {"messages": [response]}
def generate_node(state: AgentState, llm):
    """
    Produce the final assistant answer, grounding it in available context.

    Context comes either from the retrieval step (``state["context"]``) or,
    on the web-search path, from the contents of recent ToolMessages.

    Args:
        state: Graph state with message history and optional ``context``.
        llm: Chat model exposing ``invoke``.

    Returns:
        Partial state update appending the generated response message.
    """
    messages = state["messages"][-10:]  # Limit to last 10 messages to handle token limit
    # BUG FIX: the previous default of [] made ``context += <str>`` extend a
    # *list* character-by-character, and the f-string below then rendered that
    # char list. Context is a plain string throughout; also coerce a stored
    # None/empty value to "".
    context = state.get("context") or ""
    system_content = get_system_prompt()
    # Extract web search results from ToolMessages if no retrieval context.
    if not context:
        for msg in reversed(messages):
            # Web search results come as ToolMessage content.
            if isinstance(msg, ToolMessage) and msg.content:
                context += f"\n\n{msg.content}"
    if context:
        system_content += f"\n\nRelevant Context:\n{context}"
    messages_with_system = [SystemMessage(content=system_content)] + messages
    response = llm.invoke(messages_with_system)
    return {'messages': [response]}
if __name__ == "__main__":
    # Nodes are wired into a graph elsewhere; nothing to run standalone.
    pass