jtan4albany commited on
Commit
2b42d51
·
verified ·
1 Parent(s): bbf9d7e

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +80 -0
agent.py CHANGED
@@ -7,6 +7,23 @@ from pint import UnitRegistry
7
  from langchain.schema import HumanMessage, AIMessage, SystemMessage
8
  from langchain_community.chat_models import ChatOllama
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  class MathTool(Tool):
11
  name = "math_tool"
12
  description = "Safely evaluates math expressions using symbolic math."
@@ -183,3 +200,66 @@ tools = [
183
  CodeExecutionTool(),
184
  UnitConversionTool(),
185
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  from langchain.schema import HumanMessage, AIMessage, SystemMessage
8
  from langchain_community.chat_models import ChatOllama
9
 
10
+ import os
11
+ from dotenv import load_dotenv
12
+ from langgraph.graph import START, StateGraph, MessagesState
13
+ from langgraph.prebuilt import tools_condition
14
+ from langgraph.prebuilt import ToolNode
15
+ from langchain_google_genai import ChatGoogleGenerativeAI
16
+ from langchain_groq import ChatGroq
17
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
18
+ from langchain_community.tools.tavily_search import TavilySearchResults
19
+ from langchain_community.document_loaders import WikipediaLoader
20
+ from langchain_community.document_loaders import ArxivLoader
21
+ from langchain_community.vectorstores import SupabaseVectorStore
22
+ from langchain_core.messages import SystemMessage, HumanMessage
23
+ from langchain_core.tools import tool
24
+ from langchain.tools.retriever import create_retriever_tool
25
+ from supabase.client import Client, create_client
26
+
27
  class MathTool(Tool):
28
  name = "math_tool"
29
  description = "Safely evaluates math expressions using symbolic math."
 
200
  CodeExecutionTool(),
201
  UnitConversionTool(),
202
  ]
203
+
204
+
205
+ # Build graph function
206
+ def build_graph(provider: str = "groq"):
207
+ """Build the graph"""
208
+ # Load environment variables from .env file
209
+ if provider == "google":
210
+ # Google Gemini
211
+ llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
212
+ elif provider == "groq":
213
+ # Groq https://console.groq.com/docs/models
214
+ llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
215
+ elif provider == "huggingface":
216
+ # TODO: Add huggingface endpoint
217
+ llm = ChatHuggingFace(
218
+ llm=HuggingFaceEndpoint(
219
+ url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
220
+ temperature=0,
221
+ ),
222
+ )
223
+ else:
224
+ raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
225
+ # Bind tools to LLM
226
+ llm_with_tools = llm.bind_tools(tools)
227
+
228
+ # Node
229
+ def assistant(state: MessagesState):
230
+ """Assistant node"""
231
+ return {"messages": [llm_with_tools.invoke(state["messages"])]}
232
+
233
+ def retriever(state: MessagesState):
234
+ """Retriever node"""
235
+ similar_question = vector_store.similarity_search(state["messages"][0].content)
236
+ example_msg = HumanMessage(
237
+ content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
238
+ )
239
+ return {"messages": [sys_msg] + state["messages"] + [example_msg]}
240
+
241
+ builder = StateGraph(MessagesState)
242
+ builder.add_node("retriever", retriever)
243
+ builder.add_node("assistant", assistant)
244
+ builder.add_node("tools", ToolNode(tools))
245
+ builder.add_edge(START, "retriever")
246
+ builder.add_edge("retriever", "assistant")
247
+ builder.add_conditional_edges(
248
+ "assistant",
249
+ tools_condition,
250
+ )
251
+ builder.add_edge("tools", "assistant")
252
+
253
+ # Compile graph
254
+ return builder.compile()
255
+
256
+ # test
257
+ if __name__ == "__main__":
258
+ question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
259
+ # Build the graph
260
+ graph = build_graph(provider="groq")
261
+ # Run the graph
262
+ messages = [HumanMessage(content=question)]
263
+ messages = graph.invoke({"messages": messages})
264
+ for m in messages["messages"]:
265
+ m.pretty_print()