from typing import Literal

from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage, RemoveMessage
from langgraph.graph import START, END

from .state import AgentState, QueryAnalysis
from .prompts import *
from .tools import intialize_chroma_vectorstore


def router_node(state: AgentState, llm):
    """
    Takes the query (and history) and decides the next step:
    vector-store retrieval (VECTORDB), web search (WEBSEARCH), or general generation (GENERAL).
    """
    query = state["messages"][-1].content

    # Classify the query with the routing prompt; the LLM is expected to reply
    # with one of the route labels handled in routing_logic.
    rag_method_prompt = determine_rag_method_prompt()
    rag_method_result = llm.invoke([SystemMessage(content=rag_method_prompt), HumanMessage(content=query)])
    rag_method = rag_method_result.content.strip().upper()

    state["rag_method"] = rag_method
    return state


def routing_logic(state: AgentState) -> str:
    """Map the router's classification onto the name of the next node in the graph."""
    rag_method = state["rag_method"]

    if rag_method == "VECTORDB":
        return "vectordb_node"
    elif rag_method == "WEBSEARCH":
        return "web_search_agent_node"
    elif rag_method == "GENERAL":
        return "generate_node"
    else:
        print(f"ERROR: Router returned unclassified intent: {rag_method}. Terminating flow.")
        return END
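

# Sketch of how routing_logic is intended to be attached to the graph, assuming
# the three node names above are registered under those names (see the
# build_example_graph sketch at the bottom of this file):
#
#     graph.add_conditional_edges("router_node", routing_logic)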


def vectordb_node(state: AgentState, vectorstore):
    """
    Retrieve the most relevant documents from the vector store and stash them
    in the state as context for the generation step.
    """
    context_docs = vectorstore.similarity_search(
        query=state["messages"][-1].content,
        k=5
    )
    context = "\n\n".join([doc.page_content for doc in context_docs])

    state["context"] = context
    return state
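

# Note: the vectorstore argument is expected to be constructed once at graph
# build time (presumably via intialize_chroma_vectorstore() from .tools) and
# bound to this node with functools.partial, as sketched in
# build_example_graph below.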


def web_search_agent_node(state: AgentState, llm):
    """
    LLM agent that decides which web search tools to call.

    This generates an AIMessage with tool_calls; the llm passed in is expected
    to already be bound to the search tools (e.g. via llm.bind_tools).
    """
    messages = state["messages"]

    system_msg = SystemMessage(content="""You are a web search assistant.
Use the available search tools (web_search_tavily, wikipedia_search) to find information about the user's query.
Call the appropriate tool with the query.""")

    messages_with_system = [system_msg] + messages

    response = llm.invoke(messages_with_system)

    return {"messages": [response]}
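

# A minimal sketch of how the search LLM could be prepared before being bound
# to this node. The tool names web_search_tavily and wikipedia_search come from
# the system prompt above and are assumed (not confirmed) to be importable
# from .tools:
#
#     from .tools import web_search_tavily, wikipedia_search
#     search_llm = llm.bind_tools([web_search_tavily, wikipedia_search])
#
# The tool calls emitted here still need to be executed by a ToolNode; see
# build_example_graph below.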


def generate_node(state: AgentState, llm):
    """
    Generate the final answer from the last few messages plus any retrieved or
    tool-provided context.
    """
    messages = state["messages"][-10:]
    context = state.get("context", "")

    system_content = get_system_prompt()

    # If no vector-store context was set, fall back to any tool results
    # (e.g. web search output) found in the recent messages.
    if not context:
        for msg in reversed(messages):
            if isinstance(msg, ToolMessage):
                if msg.content:
                    context += f"\n\n{msg.content}"

    if context:
        system_content += f"\n\nRelevant Context:\n{context}"

    messages_with_system = [SystemMessage(content=system_content)] + messages
    response = llm.invoke(messages_with_system)

    return {"messages": [response]}


if __name__ == "__main__":
    pass
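

# A minimal end-to-end wiring sketch for these nodes. The caller supplies the
# search tools (e.g. the web_search_tavily and wikipedia_search named in the
# prompt above) and the Chroma vectorstore (presumably from
# intialize_chroma_vectorstore()); neither is confirmed by this module, so
# treat this as illustrative rather than the canonical graph.
def build_example_graph(llm, vectorstore, search_tools):
    """Compile a sketch graph: router -> {vectordb | web search + tools | general} -> generate."""
    from functools import partial

    from langgraph.graph import StateGraph
    from langgraph.prebuilt import ToolNode, tools_condition

    graph = StateGraph(AgentState)

    # Bind the shared dependencies (llm, vectorstore, tools) into the node callables.
    graph.add_node("router_node", partial(router_node, llm=llm))
    graph.add_node("vectordb_node", partial(vectordb_node, vectorstore=vectorstore))
    graph.add_node("web_search_agent_node",
                   partial(web_search_agent_node, llm=llm.bind_tools(search_tools)))
    graph.add_node("web_tools", ToolNode(search_tools))
    graph.add_node("generate_node", partial(generate_node, llm=llm))

    graph.add_edge(START, "router_node")

    # routing_logic returns the name of the next node (or END) directly.
    graph.add_conditional_edges("router_node", routing_logic)

    # Retrieval feeds straight into generation.
    graph.add_edge("vectordb_node", "generate_node")

    # If the search agent emitted tool calls, execute them, then generate;
    # otherwise skip straight to generation.
    graph.add_conditional_edges(
        "web_search_agent_node",
        tools_condition,
        {"tools": "web_tools", END: "generate_node"},
    )
    graph.add_edge("web_tools", "generate_node")

    graph.add_edge("generate_node", END)
    return graph.compile()


# Hypothetical usage (tool and vectorstore setup depend on .tools):
#
#     app = build_example_graph(llm, vectorstore, search_tools)
#     app.invoke({"messages": [HumanMessage(content="What is LangGraph?")]})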