Files changed (1) hide show
  1. agent.py +347 -0
agent.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cmath
2
+ import os
3
+ from typing import Dict, List, Sequence, TypedDict, cast
4
+
5
+ from dotenv import load_dotenv
6
+ from langchain.tools.retriever import create_retriever_tool
7
+ from langchain_community.document_loaders import ArxivLoader, WikipediaLoader
8
+ from langchain_community.vectorstores import SupabaseVectorStore
9
+ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
10
+ from langchain_core.tools import tool
11
+ from langchain_google_genai import ChatGoogleGenerativeAI
12
+ from langchain_groq import ChatGroq
13
+ from langchain_huggingface import (
14
+ ChatHuggingFace,
15
+ HuggingFaceEmbeddings,
16
+ HuggingFaceEndpoint,
17
+ )
18
+ from langchain_tavily import TavilySearch
19
+ from langgraph.graph import END, START, MessagesState, StateGraph
20
+ from langgraph.prebuilt import ToolNode, tools_condition
21
+ from pydantic import BaseModel
22
+ from supabase.client import Client, create_client
23
+
24
+ # Load environment variables from .env file
25
+ load_dotenv()
26
+
27
+
28
+ class WebSearchInput(BaseModel):
29
+ query: str
30
+
31
+
32
+ class WikipediaSearchInput(BaseModel):
33
+ query: str
34
+
35
+
36
+ class ArxivSearchInput(BaseModel):
37
+ query: str
38
+
39
+
40
+ @tool
41
+ def search_web(query: str) -> str:
42
+ """Search the web using Tavily and return relevant results."""
43
+
44
+ """Search Tavily for a query and return maximum 3 results.
45
+
46
+ Args:
47
+ query: The search query."""
48
+ search_docs = TavilySearch(max_results=3).invoke({"query": query})
49
+ formatted_search_docs = "\n\n---\n\n".join(
50
+ [
51
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
52
+ for doc in search_docs
53
+ ]
54
+ )
55
+ return {"web_results": formatted_search_docs}
56
+
57
+
58
+ @tool
59
+ def search_wikipedia(query: str) -> str:
60
+ """Search Wikipedia using LangChain's loader and return the first document summary."""
61
+ try:
62
+ loader = WikipediaLoader(query=query, lang="en", load_max_docs=2)
63
+ docs = loader.load()
64
+ if not docs:
65
+ return {"error": f"No Wikipedia articles found for query: {query}"}
66
+ formatted_docs = "\n\n---\n\n".join(
67
+ [f"Wikipedia Article: {query}\n\n{doc.page_content}" for doc in docs]
68
+ )
69
+ return {"wiki_results": formatted_docs}
70
+ except Exception as e:
71
+ return {"error": f"Error searching Wikipedia: {str(e)}"}
72
+
73
+
74
+ @tool
75
+ def arxiv_search(query: str) -> str:
76
+ """Search Arxiv for a query and return maximum 3 result.
77
+ Args:
78
+ query: The search query."""
79
+ search_docs = ArxivLoader(query=query, load_max_docs=3).load()
80
+ formatted_search_docs = "\n\n---\n\n".join(
81
+ [
82
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
83
+ for doc in search_docs
84
+ ]
85
+ )
86
+ return {"arxiv_results": formatted_search_docs}
87
+
88
+
89
+ @tool
90
+ def power(a: float, b: float) -> float:
91
+ """
92
+ Get the power of two numbers.
93
+ Args:
94
+ a (float): the first number
95
+ b (float): the second number
96
+ """
97
+ return a**b
98
+
99
+
100
+ @tool
101
+ def square_root(a: float) -> float | complex:
102
+ """
103
+ Get the square root of a number.
104
+ Args:
105
+ a (float): the number to get the square root of
106
+ """
107
+ if a >= 0:
108
+ return a**0.5
109
+ return cmath.sqrt(a)
110
+
111
+
112
+ @tool
113
+ def multiply(a: int, b: int) -> int:
114
+ """Multiply two numbers.
115
+ Args:
116
+ a: first int
117
+ b: second int
118
+ """
119
+ return a * b
120
+
121
+
122
+ @tool
123
+ def add(a: int, b: int) -> int:
124
+ """Add two numbers.
125
+ Args:
126
+ a: first int
127
+ b: second int
128
+ """
129
+ return a + b
130
+
131
+
132
+ @tool
133
+ def subtract(a: int, b: int) -> int:
134
+ """Subtract two numbers.
135
+ Args:
136
+ a: first int
137
+ b: second int
138
+ """
139
+ return a - b
140
+
141
+
142
+ @tool
143
+ def divide(a: float, b: float) -> float:
144
+ """
145
+ Divides two numbers.
146
+ Args:
147
+ a (float): the first float number
148
+ b (float): the second float number
149
+ """
150
+ if b == 0:
151
+ raise ValueError("Cannot divided by zero.")
152
+ return a / b
153
+
154
+
155
+ @tool
156
+ def modulus(a: int, b: int) -> int:
157
+ """Get the modulus of two numbers.
158
+ Args:
159
+ a: first int
160
+ b: second int
161
+ """
162
+ return a % b
163
+
164
+
165
+ # System prompt
166
+ system_prompt = SystemMessage(
167
+ content="""You are a helpful assistant tasked with answering questions using a set of tools.
168
+ Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
169
+ FINAL ANSWER: [YOUR FINAL ANSWER].
170
+ YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, Apply the rules above for each element (number or string), ensure there is exactly one space after each comma.
171
+ Your answer should only start with "FINAL ANSWER: ", then follows with the answer. """
172
+ )
173
+
174
+ supabase_url = os.environ.get("SUPABASE_URL")
175
+ supabase_service_key = os.environ.get("SUPABASE_SERVICE_KEY")
176
+ # build a retriever
177
+ embeddings = HuggingFaceEmbeddings(
178
+ model_name="sentence-transformers/all-mpnet-base-v2"
179
+ ) # dim=768
180
+ supabase: Client = create_client(supabase_url, supabase_service_key)
181
+ vector_store = SupabaseVectorStore(
182
+ client=supabase,
183
+ embedding=embeddings,
184
+ table_name="documents",
185
+ query_name="match_documents_langchain",
186
+ )
187
+ create_retriever_tool = create_retriever_tool(
188
+ retriever=vector_store.as_retriever(),
189
+ name="Question Search",
190
+ description="A tool to retrieve similar questions from a vector store.",
191
+ )
192
+
193
+ # Initialize tools
194
+ tools = [
195
+ search_wikipedia,
196
+ search_web,
197
+ arxiv_search,
198
+ power,
199
+ square_root,
200
+ multiply,
201
+ divide,
202
+ subtract,
203
+ add,
204
+ modulus,
205
+ ]
206
+
207
+
208
+ def build_agent_graph(provider: str = "groq"):
209
+ """Build the graph"""
210
+
211
+ # Initialize LLM class
212
+ try:
213
+ gemini_api_key = os.getenv("GEMINI_API_KEY")
214
+ if provider == "groq":
215
+ # Groq https://console.groq.com/docs/models
216
+ chat_model = ChatGroq(
217
+ model="qwen-qwq-32b", temperature=0
218
+ ) # optional : qwen-qwq-32b gemma2-9b-it
219
+ elif provider == "gemini":
220
+ chat_model = ChatGoogleGenerativeAI(
221
+ model="gemini-2.5-pro",
222
+ temperature=1.0,
223
+ max_retries=2,
224
+ google_api_key=gemini_api_key,
225
+ )
226
+ elif provider == "huggingface":
227
+ llm = HuggingFaceEndpoint(
228
+ url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
229
+ temperature=0,
230
+ )
231
+ chat_model = ChatHuggingFace(llm=llm, verbose=True)
232
+ else:
233
+ raise ValueError("Invalid provider.")
234
+ except Exception as e:
235
+ raise Exception(f"Failed to initialize LLM: {str(e)}")
236
+
237
+ llm_with_tools = chat_model.bind_tools(tools)
238
+
239
+ # Create nodes
240
+ def assistant(state: MessagesState):
241
+ """Assistant node"""
242
+ return {"messages": [llm_with_tools.invoke(state["messages"])]}
243
+
244
+ def retriever(state: MessagesState):
245
+ query = state["messages"][-1].content
246
+ results = vector_store.similarity_search(query, k=1)
247
+
248
+ if not results:
249
+ print(f"[retriever] No similar documents found for query: {query}")
250
+ return {
251
+ "messages": [
252
+ AIMessage(content="I couldn't find any similar content in memory.")
253
+ ]
254
+ }
255
+
256
+ similar_doc = results[0]
257
+ content = similar_doc.page_content
258
+
259
+ if "Final answer :" in content:
260
+ answer = content.split("Final answer :")[-1].strip()
261
+ else:
262
+ answer = content.strip()
263
+
264
+ return {"messages": [AIMessage(content=answer)]}
265
+
266
+ # Build graph
267
+ builder = StateGraph(MessagesState)
268
+ builder.add_node("retriever", retriever)
269
+ # builder.add_node("assistant", assistant)
270
+ # builder.add_node("tools", ToolNode(tools))
271
+ # builder.add_edge(START, "retriever")
272
+ # builder.add_edge("retriever", "assistant")
273
+ # builder.add_conditional_edges(
274
+ # "assistant",
275
+ # tools_condition,
276
+ # )
277
+ # builder.add_edge("tools", "assistant")
278
+
279
+ builder.set_entry_point("retriever")
280
+ builder.set_finish_point("retriever")
281
+
282
+ return builder.compile()
283
+
284
+
285
+ # Manual test function
286
+ def test_agent():
287
+ """Run a manual test of the agent"""
288
+ print("\n" + "=" * 50)
289
+ print("Starting Agent Test")
290
+ print("=" * 50)
291
+
292
+ # Check environment variables
293
+ if not os.getenv("HUGGINGFACEHUB_API_TOKEN"):
294
+ print("\nError: HUGGINGFACEHUB_API_TOKEN not set")
295
+ return
296
+ if not os.getenv("GEMINI_API_KEY"):
297
+ print("\nError: GEMINI_API_KEY not set")
298
+ return
299
+ if not os.getenv("TAVILY_API_KEY"):
300
+ print("\nWarning: TAVILY_API_KEY not set - web search will be unavailable")
301
+
302
+ if not os.getenv("SUPABASE_URL"):
303
+ print("\nWarning: SUPABASE_URL not set - web search will be unavailable")
304
+
305
+ print("\nInitializing agent...")
306
+ try:
307
+ graph = build_agent_graph(provider="groq")
308
+ print("Agent initialized successfully")
309
+ except Exception as e:
310
+ print(f"Failed to initialize agent: {str(e)}")
311
+ return
312
+
313
+ # Test a single question
314
+ question = "Examine the video at https://www.youtube.com/watch?v=1htKBjuUWec.\n\nWhat does Teal'c say in response to the question \"Isn't that hot?\""
315
+ print("\nTesting question:", question)
316
+ print("-" * 50)
317
+
318
+ try:
319
+ # Create messages state
320
+ messages = [HumanMessage(content=question)]
321
+
322
+ # Run agent
323
+ print("\nWaiting for response...")
324
+ result = graph.invoke({"messages": messages})
325
+
326
+ # Get answer
327
+ if result and "messages" in result and result["messages"]:
328
+
329
+ answer = result["messages"][-1].content
330
+ print("\nResponse received:")
331
+ print("-" * 20)
332
+ print(answer)
333
+ print("-" * 20)
334
+ else:
335
+ print("\nError: No response from agent")
336
+
337
+ except Exception as e:
338
+ print(f"\nError processing question: {str(e)}")
339
+
340
+ print("\n" + "=" * 50)
341
+ print("Test Complete")
342
+ print("=" * 50 + "\n")
343
+
344
+
345
+ # Run test if script is run directly
346
+ if __name__ == "__main__":
347
+ test_agent()