File size: 1,929 Bytes
1d7eb48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from langgraph.graph import StateGraph, MessagesState,START,END
from typing import Literal
from langchain_core.tools import tool
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import ToolNode
from langchain_groq import ChatGroq
from langchain_community.tools.tavily_search import TavilySearchResults
from dotenv import load_dotenv

# Load API keys (GROQ_API_KEY, TAVILY_API_KEY) from a local .env file.
load_dotenv()

# In-process checkpointer; shared by every chatbot instance in this module.
memory = MemorySaver()

class chatbot:
    """ReAct-style agent: a Groq-hosted LLM bound to a Tavily search tool,
    wired into a LangGraph state machine with memory checkpointing.

    Usage: ``app = chatbot()()`` — calling the instance builds and compiles
    the graph and returns the runnable app.
    """

    def __init__(self):
        # Groq-hosted Gemma2 9B instruct model; needs GROQ_API_KEY in the env.
        self.llm = ChatGroq(model_name="Gemma2-9b-It")
        self.memory = memory  # module-level MemorySaver shared across instances
        self.call_tool()

    def call_tool(self):
        """Create the Tavily search tool, its ToolNode, and the tool-bound LLM."""
        # Renamed from `tool` — the original shadowed the imported `tool`
        # decorator from langchain_core.tools.
        search = TavilySearchResults(max_results=2)  # needs TAVILY_API_KEY
        self.tool_node = ToolNode(tools=[search])
        self.llm_with_tool = self.llm.bind_tools([search])

    def call_model(self, state: MessagesState):
        """Agent node: run the tool-bound LLM on the conversation so far.

        NOTE: the original passed a hard-coded
        ``{"configurable": {"thread_id": "1"}}`` config to the *model* invoke.
        thread_id is a graph/checkpointer concern supplied at ``app.invoke``
        time; at the model level it is dead, misleading code — removed.
        """
        messages = state["messages"]
        response = self.llm_with_tool.invoke(messages)
        return {"messages": [response]}

    def router_function(self, state: MessagesState) -> Literal["tools", END]:
        """Route to the tool node when the last AI message requested tool calls,
        otherwise terminate the graph."""
        last_message = state["messages"][-1]
        if last_message.tool_calls:
            return "tools"
        return END

    def __call__(self):
        """Build and compile the agent graph.

        Returns:
            The compiled LangGraph app (also stored on ``self.app``). Because it
            is compiled with a checkpointer, callers must pass a
            ``configurable.thread_id`` config when invoking it.
        """
        workflow = StateGraph(MessagesState)
        workflow.add_node("agent", self.call_model)
        workflow.add_node("tools", self.tool_node)
        workflow.add_edge(START, "agent")
        # After the agent speaks, either call tools or finish.
        workflow.add_conditional_edges("agent", self.router_function, {"tools": "tools", END: END})
        # Tool results always flow back to the agent for another turn.
        workflow.add_edge("tools", "agent")

        self.app = workflow.compile(checkpointer=self.memory)
        return self.app

if __name__ == "__main__":
    mybot = chatbot()
    workflow = mybot()

    # BUG FIX: the graph is compiled with a checkpointer (MemorySaver), so
    # LangGraph requires a `configurable.thread_id` at invoke time; invoking
    # without one raises a ValueError before the model is ever called.
    config = {"configurable": {"thread_id": "1"}}
    response = workflow.invoke(
        {"messages": ["Who is the current prime minister of the USA?"]},
        config=config,
    )
    print(response['messages'][-1].content)