from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage, RemoveMessage
from typing import Literal
from langgraph.graph import START, END
from .state import AgentState, QueryAnalysis
from .prompts import *
from .tools import intialize_chroma_vectorstore


def router_node(state: AgentState, llm):
    """
    Takes the latest query (and history) and classifies it.
    Decides the next step: vector DB retrieval, web search, or general generation.
    """
    query = state["messages"][-1].content
    rag_method_prompt = determine_rag_method_prompt()
    rag_method_result = llm.invoke([SystemMessage(content=rag_method_prompt), HumanMessage(content=query)])
    rag_method = rag_method_result.content.strip().upper()
    state["rag_method"] = rag_method
    return state


def routing_logic(state: AgentState) -> str:
    """Map the router's classification to the next node."""
    rag_method = state["rag_method"]
    if rag_method == "VECTORDB":
        return "vectordb_node"
    elif rag_method == "WEBSEARCH":
        return "web_search_agent_node"
    elif rag_method == "GENERAL":
        # Fall back to generate_node if the question does not require RAG or web search
        return "generate_node"
    else:
        # The LLM violated the prompt and returned an unknown label; terminate the flow.
        print(f"ERROR: Router returned unclassified intent: {rag_method}. Terminating flow.")
        return END


def vectordb_node(state: AgentState, vectorstore):
    """
    Retrieve context for the query from the vector store.
    """
    context_docs = vectorstore.similarity_search(
        query=state["messages"][-1].content,
        k=5,
    )
    context = "\n\n".join([doc.page_content for doc in context_docs])
    state["context"] = context
    return state


def web_search_agent_node(state: AgentState, llm):
    """
    LLM agent that decides which web search tools to call.
    This generates an AIMessage with tool_calls.
    """
    messages = state["messages"]

    # Add an instruction to use the search tools
    system_msg = SystemMessage(content="""You are a web search assistant.
Use the available search tools (web_search_tavily, wikipedia_search) to find information about the user's query.
Call the appropriate tool with the query.""")

    messages_with_system = [system_msg] + messages

    # An LLM with tools bound will generate an AIMessage carrying tool_calls
    response = llm.invoke(messages_with_system)

    return {"messages": [response]}


def generate_node(state: AgentState, llm):
    """
    Generate the final answer, injecting retrieved or web-search context
    into the system prompt.
    """
    messages = state["messages"][-10:]  # Limit to the last 10 messages to stay within the token limit
    context = state.get("context", "")

    system_content = get_system_prompt()

    # Extract web search results from ToolMessages if no vector DB context is available
    if not context:
        for msg in reversed(messages):
            if isinstance(msg, ToolMessage):
                # Web search results arrive as ToolMessage content
                if msg.content:
                    context += f"\n\n{msg.content}"

    if context:
        system_content += f"\n\nRelevant Context:\n{context}"

    messages_with_system = [SystemMessage(content=system_content)] + messages

    response = llm.invoke(messages_with_system)

    return {"messages": [response]}


if __name__ == "__main__":
    pass
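
# ---------------------------------------------------------------------------
# Example wiring (a sketch, not part of the original flow): how these nodes
# could be assembled into a LangGraph StateGraph. The chat model, the tool
# functions `web_search_tavily` / `wikipedia_search`, and the
# `intialize_chroma_vectorstore` factory are assumptions based on the imports
# and prompts above; adjust them to whatever your .tools module exposes.
# ---------------------------------------------------------------------------
def build_graph():
    """Assemble the router -> retrieval/search -> generation graph (sketch)."""
    from functools import partial

    from langchain_openai import ChatOpenAI  # assumed provider/model
    from langgraph.graph import StateGraph
    from langgraph.prebuilt import ToolNode

    from .tools import web_search_tavily, wikipedia_search  # assumed tool functions

    llm = ChatOpenAI(model="gpt-4o-mini")  # assumed model name
    search_tools = [web_search_tavily, wikipedia_search]
    vectorstore = intialize_chroma_vectorstore()  # assumed to return a Chroma store

    builder = StateGraph(AgentState)
    builder.add_node("router_node", partial(router_node, llm=llm))
    builder.add_node("vectordb_node", partial(vectordb_node, vectorstore=vectorstore))
    builder.add_node(
        "web_search_agent_node",
        partial(web_search_agent_node, llm=llm.bind_tools(search_tools)),
    )
    builder.add_node("tools", ToolNode(search_tools))  # executes the requested tool_calls
    builder.add_node("generate_node", partial(generate_node, llm=llm))

    builder.add_edge(START, "router_node")
    builder.add_conditional_edges("router_node", routing_logic)  # returns a node name or END
    builder.add_edge("vectordb_node", "generate_node")
    builder.add_edge("web_search_agent_node", "tools")
    builder.add_edge("tools", "generate_node")
    builder.add_edge("generate_node", END)

    return builder.compile()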