# Source: Hugging Face Space file view (Space status: "Sleeping"),
# commit 1d7eb48, file size 1,929 bytes, 55 lines.
from langgraph.graph import StateGraph, MessagesState,START,END
from typing import Literal
from langchain_core.tools import tool
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import ToolNode
from langchain_groq import ChatGroq
from langchain_community.tools.tavily_search import TavilySearchResults
from dotenv import load_dotenv
# Load variables from a local .env file into the process environment;
# variables that are already set are kept (override=False is python-dotenv's
# default, spelled out here for clarity).
# NOTE(review): presumably this is where the Groq/Tavily API keys come from — confirm.
load_dotenv(override=False)

# Module-level default checkpointer; chatbot instances store it as
# ``self.memory`` and compile their graph with it.
memory: MemorySaver = MemorySaver()
class chatbot:
    """Tool-calling chat agent: a Groq chat model that can use Tavily search.

    Constructing an instance creates the LLM and the search tool; calling the
    instance builds and compiles the LangGraph workflow

        START -> agent -> (tools -> agent)* -> END

    and returns the compiled app (also cached on ``self.app``).

    The graph is compiled with a checkpointer, so LangGraph requires callers
    to pass a thread id when invoking it, e.g.
    ``app.invoke(inputs, config={"configurable": {"thread_id": "1"}})``.
    """

    def __init__(self, model_name="Gemma2-9b-It", max_results=2, checkpointer=None):
        """Create the LLM and wire up the search tool.

        Args:
            model_name: Groq model identifier passed to ``ChatGroq``.
            max_results: Maximum number of Tavily results returned per search.
            checkpointer: Checkpointer used when compiling the graph. ``None``
                (the default) uses the module-level ``memory`` object, which is
                then shared by every instance that does not supply its own.
        """
        self.llm = ChatGroq(model_name=model_name)
        self.memory = memory if checkpointer is None else checkpointer
        self.call_tool(max_results)

    def call_tool(self, max_results=2):
        """Prepare the search tool, its ToolNode, and a tool-bound LLM.

        Despite the name, no tool is executed here; this only sets
        ``self.tool_node`` and ``self.llm_with_tool``.

        Args:
            max_results: Maximum number of Tavily results returned per search.
        """
        # Named ``search_tool`` so it does not shadow the ``tool`` decorator
        # imported at module level.
        search_tool = TavilySearchResults(max_results=max_results)
        self.tool_node = ToolNode(tools=[search_tool])
        self.llm_with_tool = self.llm.bind_tools([search_tool])

    def call_model(self, state: MessagesState, config=None):
        """Agent node: send the conversation so far to the tool-bound LLM.

        Args:
            state: Graph state; ``state["messages"]`` is the message history.
            config: Run config. LangGraph injects it into nodes that declare a
                ``config`` parameter. It is forwarded so callbacks, tracing and
                the caller's ``thread_id`` reach the model call. ``None`` when
                the method is called directly.

        Returns:
            A partial state update holding the model's reply under ``messages``.
        """
        response = self.llm_with_tool.invoke(state["messages"], config=config)
        return {"messages": [response]}

    def router_function(self, state: MessagesState) -> Literal["tools", END]:
        """Pick the next node after the agent runs.

        Returns:
            ``"tools"`` if the latest message requests tool calls, else ``END``.
        """
        last_message = state["messages"][-1]
        # Only AI messages carry ``tool_calls``; any other message type ends
        # the run instead of raising AttributeError.
        if getattr(last_message, "tool_calls", None):
            return "tools"
        return END

    def __call__(self):
        """Build, compile, cache on ``self.app``, and return the agent graph."""
        workflow = StateGraph(MessagesState)
        workflow.add_node("agent", self.call_model)
        workflow.add_node("tools", self.tool_node)
        workflow.add_edge(START, "agent")
        # router_function's return value selects the next node.
        workflow.add_conditional_edges(
            "agent", self.router_function, {"tools": "tools", END: END}
        )
        # Tool results always go back to the LLM for another turn.
        workflow.add_edge("tools", "agent")
        self.app = workflow.compile(checkpointer=self.memory)
        return self.app
if __name__ == "__main__":
    # Build the agent and compile its graph.
    app = chatbot()()
    # The graph is compiled with a checkpointer, so LangGraph requires a
    # thread_id; invocations that share a thread_id share conversation history.
    run_config = {"configurable": {"thread_id": "1"}}
    response = app.invoke(
        {"messages": ["Who is the current prime minister of the USA?"]},
        config=run_config,
    )
    # The final message in the state is the model's answer.
    print(response["messages"][-1].content)
|