File size: 3,167 Bytes
aa018e3
30ee88a
0fc97a4
30ee88a
 
 
0fc97a4
30ee88a
 
0fc97a4
30ee88a
0fc97a4
 
 
 
aa018e3
0fc97a4
 
 
 
aa018e3
0fc97a4
94e0eef
0fc97a4
 
aa018e3
0fc97a4
 
 
 
 
 
30ee88a
aa018e3
30ee88a
0fc97a4
30ee88a
0fc97a4
 
 
30ee88a
0fc97a4
 
 
 
aa018e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0fc97a4
 
 
 
 
 
aa018e3
 
 
 
 
 
 
 
0fc97a4
 
 
 
 
 
 
 
30ee88a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage, RemoveMessage
from typing import Literal
from langgraph.graph import START, END

from .state import AgentState, QueryAnalysis
from .prompts import *
from .tools import intialize_chroma_vectorstore


def router_node(state: AgentState, llm):
    """
    Classify the latest message and record which retrieval route to take.

    The content of the last message in ``state["messages"]`` is sent to
    ``llm`` together with the prompt from ``determine_rag_method_prompt()``.
    The reply is stripped and upper-cased so that ``routing_logic`` can
    compare it against its labels ("VECTORDB", "WEBSEARCH", "GENERAL").

    Args:
        state: Current graph state; only the last entry of ``messages`` is read.
        llm: Object exposing ``invoke(messages)`` whose result has a string
            ``content`` attribute.

    Returns:
        A partial state update ``{"rag_method": <label>}``. The incoming
        ``state`` is not mutated.
    """
    query = state["messages"][-1].content
    classifier_messages = [
        SystemMessage(content=determine_rag_method_prompt()),
        HumanMessage(content=query),
    ]
    reply = llm.invoke(classifier_messages)

    # Normalise whitespace/case so e.g. " vectordb\n" still matches a label.
    rag_method = reply.content.strip().upper()

    # Return only the key this node owns (as web_search_agent_node and
    # generate_node do) instead of mutating the input and echoing the whole
    # state back, which would re-submit every other key, messages included.
    return {"rag_method": rag_method}

def routing_logic(state: AgentState) -> str:
    """
    Map the router's decision in ``state["rag_method"]`` to the next node name.

    Known labels go to their node; any other value is reported on stdout and
    the flow is ended by returning ``END``.
    """
    rag_method = state["rag_method"]

    known_routes = (
        ("VECTORDB", "vectordb_node"),
        ("WEBSEARCH", "web_search_agent_node"),
        # Questions that need neither RAG nor web search go straight to generation.
        ("GENERAL", "generate_node"),
    )
    for label, target_node in known_routes:
        if rag_method == label:
            return target_node

    # The LLM ignored the prompt and produced a word outside the known labels.
    print(f"ERROR: Router returned unclassified intent: {rag_method}. Terminating flow.")
    return END

def vectordb_node(state: AgentState, vectorstore, k: int = 5) -> dict:
    """
    Retrieve passages similar to the latest message from the vector store.

    Args:
        state: Current graph state; the content of the last entry in
            ``messages`` is used as the search query.
        vectorstore: Object exposing ``similarity_search(query=..., k=...)``
            that returns items with a ``page_content`` attribute.
        k: Number of results requested from the store. Defaults to 5, the
            value that used to be hard-coded.

    Returns:
        A partial state update ``{"context": <text>}`` where the text is the
        results' ``page_content`` values joined by blank lines (an empty
        string when nothing is returned). The incoming ``state`` is not
        mutated.
    """
    query = state["messages"][-1].content
    context_docs = vectorstore.similarity_search(query=query, k=k)

    # Return only the key this node owns (as web_search_agent_node and
    # generate_node do) instead of mutating the input and echoing the whole
    # state back, which would re-submit every other key, messages included.
    return {"context": "\n\n".join(doc.page_content for doc in context_docs)}

def web_search_agent_node(state: AgentState, llm):
    """
    Ask the LLM which web search tool to call for the current conversation.

    A fixed system instruction is placed in front of the full message history
    from ``state["messages"]`` and the model's reply is returned as a
    one-element ``messages`` update.

    NOTE(review): ``llm`` is presumably passed in with the search tools
    already bound so the reply carries tool_calls; confirm at the call site.
    """
    history = state["messages"]

    # Spelled out as explicit literals; the runtime text (including the
    # trailing space before the first newline) is unchanged.
    instruction = (
        "You are a web search assistant. \n"
        "Use the available search tools (web_search_tavily, wikipedia_search) "
        "to find information about the user's query.\n"
        "Call the appropriate tool with the query."
    )

    prompt = [SystemMessage(content=instruction)] + history
    return {"messages": [llm.invoke(prompt)]}

def generate_node(state: AgentState, llm):
    """
    Produce the answer from recent history plus any available context.

    Context comes from ``state["context"]`` when it is present and non-empty.
    Otherwise it is assembled from the content of every ``ToolMessage`` found
    in the recent messages, newest first. When any context exists it is
    appended to the system prompt under a "Relevant Context:" heading.

    Args:
        state: Current graph state; reads ``messages`` and, optionally,
            ``context``.
        llm: Object exposing ``invoke(messages)``.

    Returns:
        A partial state update ``{"messages": [response]}`` holding the
        model's reply.
    """
    messages = state["messages"][-10:]  # Limit to last 10 messages to handle token limit

    # Treat a missing, None or empty context uniformly as "". The earlier
    # default of [] meant `context += "<text>"` extended the list one
    # character at a time, so the prompt received the repr of a char list.
    context = state.get("context") or ""

    # Fall back to web search results, which arrive as ToolMessage content.
    if not context:
        context = "".join(
            f"\n\n{msg.content}"
            for msg in reversed(messages)
            if isinstance(msg, ToolMessage) and msg.content
        )

    system_content = get_system_prompt()
    if context:
        system_content += f"\n\nRelevant Context:\n{context}"

    response = llm.invoke([SystemMessage(content=system_content), *messages])
    return {'messages': [response]}
    
# Running this module directly is a no-op; the guard is only a placeholder.
if __name__ == "__main__":
    pass