# LangGraph tool-calling chatbot: Groq-hosted Gemma2 model + Tavily web search.
from typing import Literal

from dotenv import load_dotenv
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.tools import tool
from langchain_groq import ChatGroq
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.prebuilt import ToolNode

# Load API credentials (e.g. GROQ_API_KEY, TAVILY_API_KEY) from a .env file.
load_dotenv()

# Module-level checkpointer shared by every chatbot instance: conversation
# state is persisted in memory, keyed by the thread_id supplied at invoke time.
memory = MemorySaver()
class chatbot:
    """Tool-calling chat agent built on LangGraph.

    Wires a Groq-hosted Gemma2 model to a Tavily web-search tool inside a
    ``StateGraph``: the ``agent`` node runs the model, and whenever the last
    response contains tool calls the graph routes to the ``tools`` node and
    loops back until the model answers without requesting a tool.
    """

    def __init__(self):
        # Groq chat model; expects GROQ_API_KEY in the environment
        # (loaded by the module-level load_dotenv()).
        self.llm = ChatGroq(model_name="Gemma2-9b-It")
        # Shared module-level checkpointer so conversation state persists
        # across invocations keyed by thread_id.
        self.memory = memory
        self.call_tool()

    def call_tool(self):
        """Build the Tavily search tool, its ToolNode, and a tool-bound LLM."""
        # Renamed from `tool` to avoid shadowing the imported
        # langchain_core.tools.tool decorator.
        search_tool = TavilySearchResults(max_results=2)
        self.tool_node = ToolNode(tools=[search_tool])
        self.llm_with_tool = self.llm.bind_tools([search_tool])

    def call_model(self, state: MessagesState):
        """Graph node: run the tool-bound LLM on the accumulated messages.

        NOTE: the original passed a hard-coded ``{"thread_id": "1"}`` config
        to the raw LLM invoke, where the checkpointer never sees it — the
        per-thread config belongs on the compiled graph's ``invoke()`` call,
        so it has been removed here.
        """
        messages = state["messages"]
        response = self.llm_with_tool.invoke(messages)
        # MessagesState appends returned messages to the running history.
        return {"messages": [response]}

    def router_function(self, state: MessagesState) -> Literal["tools", END]:
        """Route to the tools node when the last AI message requested tool calls."""
        last_message = state["messages"][-1]
        if last_message.tool_calls:
            return "tools"
        return END

    def __call__(self):
        """Build and compile the agent graph; returns the runnable app."""
        workflow = StateGraph(MessagesState)
        workflow.add_node("agent", self.call_model)
        workflow.add_node("tools", self.tool_node)
        workflow.add_edge(START, "agent")
        workflow.add_conditional_edges(
            "agent", self.router_function, {"tools": "tools", END: END}
        )
        # After the tools run, hand their results back to the model.
        workflow.add_edge("tools", "agent")
        self.app = workflow.compile(checkpointer=self.memory)
        return self.app
if __name__ == "__main__":
    mybot = chatbot()
    workflow = mybot()
    # A graph compiled with a checkpointer REQUIRES a thread_id under the
    # "configurable" key at invoke time; omitting it (as the original did)
    # makes langgraph raise a ValueError before the model ever runs.
    config = {"configurable": {"thread_id": "1"}}
    response = workflow.invoke(
        {"messages": ["Who is the current prime minister of the USA?"]},
        config=config,
    )
    print(response["messages"][-1].content)