tbuyuktanir commited on
Commit
eb8329a
·
verified ·
1 Parent(s): 81917a3

Upload agent.py

Browse files
Files changed (1) hide show
  1. agent.py +199 -0
agent.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ from langchain_community.tools import DuckDuckGoSearchResults
4
+ from langgraph.graph import START, StateGraph, MessagesState
5
+ from langgraph.prebuilt import tools_condition
6
+ from langgraph.prebuilt import ToolNode
7
+ from langchain_google_genai import ChatGoogleGenerativeAI
8
+ from langchain_groq import ChatGroq
9
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
10
+ from langchain_community.tools.tavily_search import TavilySearchResults
11
+ from langchain_community.document_loaders import WikipediaLoader
12
+ from langchain_community.document_loaders import ArxivLoader
13
+ from langchain_community.vectorstores import SupabaseVectorStore
14
+ from langchain_core.messages import SystemMessage, HumanMessage
15
+ from langchain_core.tools import tool
16
+ from langchain.tools.retriever import create_retriever_tool
17
+ from supabase.client import Client, create_client
18
+ from langchain_openai import ChatOpenAI
19
+
20
+ load_dotenv()
21
+
22
@tool
def add(x: int, y: int) -> int:
    """Return the sum of two integers.

    :arg x: The first number.
    :arg y: The second number.
    """
    total = x + y
    return total
29
+
30
@tool
def subtract(x: int, y: int) -> int:
    """Return the difference of two integers (x minus y).

    :arg x: The first number.
    :arg y: The second number.
    """
    difference = x - y
    return difference
37
@tool
def multiply(x: int, y: int) -> int:
    """Return the product of two integers.

    :arg x: The first number.
    :arg y: The second number.
    """
    product = x * y
    return product
44
+
45
@tool
def divide(x: int, y: int) -> float:
    """Return the quotient of two integers as a float.

    :arg x: The first number.
    :arg y: The second number.
    :raises ValueError: If y is zero.
    """
    # Guard clause: refuse a zero divisor up front.
    if not y:
        raise ValueError("Cannot divide by zero.")
    return x / y
55
+
56
@tool
def modulus(x: int, y: int) -> int:
    """Calculates the modulus of two numbers.

    :arg x: The first number.
    :arg y: The second number.
    :raises ValueError: If y is zero.
    """
    # The docstring promises ValueError on a zero divisor, but the original
    # body had no check, so `x % 0` escaped as ZeroDivisionError. Add the
    # guard to honor the documented contract, matching divide().
    if y == 0:
        raise ValueError("Cannot divide by zero.")
    return x % y
64
@tool
def wiki_search(query: str) -> str:
    """Searches Wikipedia for the given query and returns the top results.

    :arg query: The search query.
    """
    docs = WikipediaLoader(query=query, load_max_docs=2).load()
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
        for doc in docs
    )
    # Return the formatted string itself: the original wrapped it in a dict,
    # contradicting the declared `-> str` return type; a plain string is
    # what a LangChain tool is expected to hand back to the model.
    return formatted_search_docs
77
+
78
@tool
def web_search(query: str) -> str:
    """Searches the web for the given query using DuckDuckGo and returns the top results.

    :arg query: The search query.
    """
    # Fixes from the original:
    #  * DuckDuckGoSearchResults takes the query at run() time, not in its
    #    constructor, so the original `.run()` call searched for nothing;
    #  * run() returns an already-formatted string, not a list of dicts, so
    #    iterating it and indexing result["source"] would raise TypeError;
    #  * docstring mentioned Tavily although the tool used is DuckDuckGo.
    search = DuckDuckGoSearchResults(num_results=3)
    print(f"Running web search for query(DuckDuckGo): {query}")
    return search.run(query)
90
+
91
+
92
@tool
def arvix_search(query: str) -> str:
    """Searches Arxiv for the given query and returns the top results.

    :arg query: The search query.
    """
    docs = ArxivLoader(query=query, load_max_docs=3).load()
    # NOTE(review): ArxivLoader metadata exposes keys like Published/Title/
    # Authors and may not include "source"; use .get() so a missing key
    # cannot raise KeyError mid-formatting — TODO confirm against the
    # installed langchain_community version.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata.get("source", "")}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
        for doc in docs
    )
    # Return the string directly to match the declared `-> str` return type
    # (the original wrapped it in a dict).
    return formatted_search_docs
105
+
106
+
107
# Load the system prompt shipped alongside this module.
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()

sys_msg = SystemMessage(content=system_prompt)

# build a retriever backed by Supabase pgvector
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")  # dim=768
supabase: Client = create_client(
    os.environ.get("SUPABASE_URL"),
    os.environ.get("SUPABASE_SERVICE_KEY"))
print("Supabase client created.")
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents",
    query_name="match_documents_langchain",
)
print("Vector store initialized with Supabase.")
# Bug fix: the original assigned this result back to the name
# `create_retriever_tool`, shadowing the imported factory function.
retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="Question Search",
    description="A tool to retrieve similar questions from a vector store.",
)
print("Retriever tool created.")
# Tools exposed to the LLM via bind_tools() in build_graph().
tools = [
    add,
    subtract,
    multiply,
    divide,
    modulus,
    wiki_search,
    web_search,
    arvix_search,
]
141
+
142
def build_graph(provider: str = "huggingface") -> StateGraph:
    """Build and compile the agent graph: START -> retriever -> assistant <-> tools.

    :arg provider: LLM backend to use: "google", "groq", "openai" or "huggingface".
    :raises ValueError: If provider is not one of the supported backends.
    """
    if provider == "google":
        # Google Gemini
        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    elif provider == "groq":
        # Groq https://console.groq.com/docs/models
        llm = ChatGroq(model="qwen-qwq-32b", temperature=0)  # optional : qwen-qwq-32b gemma2-9b-it
    elif provider == "openai":
        # OpenAI
        llm = ChatOpenAI(model="gpt-4o", temperature=0)
    elif provider == "huggingface":
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                repo_id="Qwen/Qwen2.5-Coder-32B-Instruct"
            ),
        )
    else:
        # Bug fix: the original message omitted 'openai' even though the
        # branch above accepts it.
        raise ValueError("Invalid provider. Choose 'google', 'groq', 'openai' or 'huggingface'.")
    # Bind tools to LLM
    llm_with_tools = llm.bind_tools(tools)

    # Node
    def assistant(state: MessagesState):
        """Assistant node: one LLM step over the accumulated messages."""
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    def retriever(state: MessagesState):
        """Retriever node: prepend the system prompt and a similar Q/A example."""
        similar_question = vector_store.similarity_search(state["messages"][0].content)
        example_msg = HumanMessage(
            content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
        )
        return {"messages": [sys_msg] + state["messages"] + [example_msg]}

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    # Route to "tools" when the assistant emitted tool calls, else END.
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
    )
    builder.add_edge("tools", "assistant")

    # Compile graph
    return builder.compile()
190
+
191
if __name__ == "__main__":
    # Smoke-test the agent on a single question using the OpenAI backend.
    question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
    agent = build_graph(provider="openai")
    result = agent.invoke({"messages": [HumanMessage(content=question)]})
    for message in result["messages"]:
        message.pretty_print()