File size: 3,167 Bytes
aa018e3 30ee88a 0fc97a4 30ee88a 0fc97a4 30ee88a 0fc97a4 30ee88a 0fc97a4 aa018e3 0fc97a4 aa018e3 0fc97a4 94e0eef 0fc97a4 aa018e3 0fc97a4 30ee88a aa018e3 30ee88a 0fc97a4 30ee88a 0fc97a4 30ee88a 0fc97a4 aa018e3 0fc97a4 aa018e3 0fc97a4 30ee88a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage, RemoveMessage
from typing import Literal
from langgraph.graph import START, END
from .state import AgentState, QueryAnalysis
from .prompts import *
from .tools import intialize_chroma_vectorstore
def router_node(state: AgentState, llm):
    """
    Classify the latest user query and record the chosen routing strategy.

    Sends the query to the LLM together with the routing system prompt and
    stores the upper-cased classification label (e.g. VECTORDB / WEBSEARCH /
    GENERAL) under ``state["rag_method"]`` for ``routing_logic`` to consume.
    """
    latest_query = state["messages"][-1].content
    routing_prompt = determine_rag_method_prompt()
    classification = llm.invoke([
        SystemMessage(content=routing_prompt),
        HumanMessage(content=latest_query),
    ])
    state["rag_method"] = classification.content.strip().upper()
    return state
def routing_logic(state: AgentState) -> str:
    """
    Map the router's classification onto the next node in the graph.

    Returns the node name for the recognized labels, or END when the LLM
    produced a label outside the expected set.
    """
    destinations = {
        "VECTORDB": "vectordb_node",
        "WEBSEARCH": "web_search_agent_node",
        # fallback to generate_node if the question do not requires RAG or websearch
        "GENERAL": "generate_node",
    }
    rag_method = state["rag_method"]
    try:
        return destinations[rag_method]
    except KeyError:
        # If the LLM violates the prompt and outputs an unknown word,
        print(f"ERROR: Router returned unclassified intent: {rag_method}. Terminating flow.")
        return END
def vectordb_node(state: AgentState, vectorstore):
    """
    Retrieve supporting documents for the latest query from the vector store.

    Runs a top-5 similarity search against the store and joins the page
    contents into one context string saved on ``state["context"]``.
    """
    latest_query = state["messages"][-1].content
    hits = vectorstore.similarity_search(query=latest_query, k=5)
    state["context"] = "\n\n".join(doc.page_content for doc in hits)
    return state
def web_search_agent_node(state: AgentState, llm):
    """
    Ask the tool-bound LLM to plan web searches for the user's query.

    Prepends a search-assistant instruction to the conversation and invokes
    the LLM; the resulting AIMessage is expected to carry tool_calls for the
    bound search tools.
    """
    instruction = SystemMessage(content="""You are a web search assistant.
Use the available search tools (web_search_tavily, wikipedia_search) to find information about the user's query.
Call the appropriate tool with the query.""")
    conversation = [instruction, *state["messages"]]
    # LLM with tools bound will generate AIMessage with tool_calls
    return {"messages": [llm.invoke(conversation)]}
def generate_node(state: AgentState, llm):
    """
    Produce the final answer from recent conversation plus gathered context.

    Prefers vector-store context already present on the state; when absent,
    falls back to collecting web-search results from ToolMessages in the
    recent history. Any context found is appended to the system prompt
    before invoking the LLM.
    """
    messages = state["messages"][-10:]  # Limit to last 10 messages to handle token limit
    # BUG FIX: the default used to be [] (a list). `list += str` extends the
    # list character-by-character, so the fallback below garbled the context
    # and the f-string later rendered a char list. Default to "" so plain
    # string concatenation works.
    context = state.get("context") or ""
    system_content = get_system_prompt()
    # Extract web search results from ToolMessages if available
    if not context:
        for msg in reversed(messages):
            # Web search results come as ToolMessage content
            if isinstance(msg, ToolMessage) and msg.content:
                context += f"\n\n{msg.content}"
    if context:
        system_content += f"\n\nRelevant Context:\n{context}"
    messages_with_system = [SystemMessage(content=system_content)] + messages
    response = llm.invoke(messages_with_system)
    return {'messages': [response]}
if __name__ == "__main__":
    # Module is import-only for the graph builder; no standalone entry point yet.
    pass